From 8408dda9aa172bd4baec79393c923a02a5f8d3e7 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Fri, 28 Apr 2017 16:42:14 -0700 Subject: [PATCH 01/21] rebase --- .../apache/spark/unsafe/types/UTF8String.java | 109 ++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 89 +++++-- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 + .../expressions/stringExpressions.scala | 250 ++++++++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 40 ++- .../expressions/StringExpressionsSuite.scala | 72 +++-- .../org/apache/spark/sql/functions.scala | 30 ++- .../spark/sql/StringFunctionsSuite.scala | 4 + 8 files changed, 532 insertions(+), 69 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 43f57672d9544..0509902ee1029 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -510,6 +510,61 @@ public UTF8String trim() { } } + /** + * Removes all specified trim character either from the beginning or the ending of a string + * @param trimChar the trim character + */ + public UTF8String trim(UTF8String trimChar) { + int numTrimBytes = trimChar.numBytes; + if (numTrimBytes == 0) { + return this; + } + int s = 0; + // e is the search index in the input string, starting with the trim character's bytes + // boundary, moving from right to left. + int e = this.numBytes - numTrimBytes; + // skip all the consecutive matching characters from left to right. + while (s < this.numBytes && s == this.find(trimChar, s)) { + s += numTrimBytes; + } + // skip all the consecutive matching character in the right side. + // if the trimming character has more bytes than the input string, 'e' points to the end + // of input string. 
+ // The search index 'e' is first positioned at an offset from the end of the input + // string equal to the number of bytes of the trimming character. If a match is found, keep + // moving left until the string is exhausted or a non-matching character is hit. Every move + // is the number of bytes of the trimming character. When a non-matching character is hit, + // 'e' needs to be positioned back to the last byte of the non-matching character. + // example 1: + // trim character: 数, input string: 头, both characters have 3 bytes. e starts at + // 0, rfind does not find a match, so index 'e' goes back to the last byte of + // the non-matching position. + // example 2: + // trim character: 数, input string: a, 'a' has 1 byte, '数' has 3 bytes, e starts at -2, + // so the whole input string is returned. + // example 3: + // trim character: 数, input string: aaa数, 'aaa数' has 6 bytes, '数' has 3 bytes, e starts at + // 3, a match is found, e moves 3 bytes to position 0, no match is found, so index e goes back + // to the last byte of the non-matching position. 
+ if (e < 0) { + e = this.numBytes - 1; + } else { + while (e >= 0 && e == this.rfind(trimChar, e)) { + e -= numTrimBytes; + } + if (e >= 0) { + e += numTrimBytes - 1; + } + } + + if (s > e) { + // empty string + return UTF8String.EMPTY_UTF8; + } else { + return copyUTF8String(s, e); + } + } + public UTF8String trimLeft() { int s = 0; // skip all of the space (0x20) in the left side @@ -522,6 +577,28 @@ public UTF8String trimLeft() { } } + /** + * Removes all specified trim character from the beginning of a string + * @param trimChar the trim character + */ + public UTF8String trimLeft(UTF8String trimChar) { + int numTrimBytes = trimChar.numBytes; + if (numTrimBytes == 0) { + return this; + } + int s = 0; + // skip all the consecutive matching character in the left side + while(s < this.numBytes && s == this.find(trimChar, s)) { + s += numTrimBytes; + } + if (s == this.numBytes) { + // empty string + return UTF8String.EMPTY_UTF8; + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side @@ -535,6 +612,38 @@ public UTF8String trimRight() { } } + /** + * Removes all specified trim character from the ending of a string + * @param trimChar the trim character + */ + public UTF8String trimRight(UTF8String trimChar) { + int numTrimBytes = trimChar.numBytes; + if (numTrimBytes == 0) { + return this; + } + int e = this.numBytes - numTrimBytes; + // skip all the consecutive matching character in the right side + // index 'e' points to first no matching byte position in the input string from right side. + // Index 'e' moves the number of bytes of the trimming character first. 
+ if (e < 0) { + e = this.numBytes - 1; + } else { + while (e >= 0 && e == this.rfind(trimChar, e)) { + e -= numTrimBytes; + } + if (e >= 0) { + e += numTrimBytes - 1; + } + } + + if (e < 0) { + // empty string + return UTF8String.EMPTY_UTF8; + } else { + return copyUTF8String(0, e); + } + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index c376371abdf90..b42d330834beb 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -222,31 +222,21 @@ public void substring() { @Test public void trims() { - assertEquals(fromString("hello"), fromString(" hello ").trim()); - assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); - assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(""))); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft(fromString(""))); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight(fromString(""))); - assertEquals(EMPTY_UTF8, fromString(" ").trim()); - assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); - assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(""))); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(""))); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(""))); - assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); - assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); - assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); - - assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); - assertEquals(fromString("数据砖头"), 
fromString("数据砖头").trimLeft()); - assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim(fromString(""))); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(""))); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(""))); - char[] charsLessThan0x20 = new char[10]; - Arrays.fill(charsLessThan0x20, (char)(' ' - 1)); - String stringStartingWithSpace = - new String(charsLessThan0x20) + "hello" + new String(charsLessThan0x20); - assertEquals(fromString(stringStartingWithSpace), fromString(stringStartingWithSpace).trim()); - assertEquals(fromString(stringStartingWithSpace), - fromString(stringStartingWithSpace).trimLeft()); - assertEquals(fromString(stringStartingWithSpace), - fromString(stringStartingWithSpace).trimRight()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim(fromString(""))); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft(fromString(""))); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight(fromString(""))); } @Test @@ -730,4 +720,59 @@ public void testToLong() throws IOException { assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); } } + @Test + public void trimsChar() { + assertEquals(fromString(" hello "), fromString(" hello ").trim(fromString(""))); + assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); + assertEquals(fromString("he"), fromString("ooheooo").trim(fromString("o"))); + assertEquals(fromString(""), fromString("ooooooo").trim(fromString("o"))); + assertEquals(fromString("b"), fromString("b").trim(fromString("o"))); + assertEquals(fromString(" "), fromString(" ooooooo").trim(fromString("o"))); + assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); + assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); + assertEquals(fromString("b"), 
fromString("b").trimLeft(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); + assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); + assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); + assertEquals(fromString("hello"), fromString("oohello").trimLeft(fromString("o"))); + assertEquals(fromString(" "), fromString("oooo ").trimLeft(fromString("o"))); + assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimRight(fromString("a"))); + assertEquals(fromString("ab"), fromString("ab").trimRight(fromString("a"))); + assertEquals(fromString(" hello"), fromString(" hello ").trimRight(fromString(" "))); + assertEquals(fromString("oohell"), fromString("oohelloooo").trimRight(fromString("o"))); + assertEquals(fromString(" oohello "), fromString(" oohello ").trimRight(fromString("o"))); + assertEquals(fromString(" oohell"), fromString(" oohelloo").trimRight(fromString("o"))); + assertEquals(fromString(""), fromString("ooooooo").trimRight(fromString("o"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" "))); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数"), fromString("数").trim(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trim(fromString("数"))); + assertEquals(fromString(""), fromString("数数数数数").trim(fromString("数"))); + assertEquals(fromString("据砖头"), fromString("数数数据砖头数数").trim(fromString("数"))); + assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); + assertEquals(fromString(" 数数数据砖头"), fromString(" 数数数据砖头数数").trim(fromString("数"))); + 
assertEquals(fromString("a数数数据砖头数数a"), fromString("a数数数据砖头数数a").trim(fromString("数"))); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); + assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); + assertEquals(fromString("据砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("数"))); + assertEquals(fromString(" 数数数据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数"))); + assertEquals(fromString("数数数据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a"))); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); + assertEquals(fromString("数"), fromString("数").trimRight(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trimRight(fromString("数"))); + assertEquals(fromString("头"), fromString("头").trimRight(fromString("数"))); + assertEquals(fromString("头"), fromString("头数数数").trimRight(fromString("数"))); + assertEquals(fromString("数数数据砖头"), fromString("数数数据砖头数数").trimRight(fromString("数"))); + assertEquals(fromString("数数数据砖头数数 "), fromString("数数数据砖头数数 ").trimRight(fromString("数"))); + assertEquals(fromString("aa数数数"), fromString("aa数数数aaa").trimRight(fromString("a"))); + assertEquals(fromString("数数"), fromString("数数").trimRight(fromString("a"))); + assertEquals(fromString("数数aa"), fromString("数数aa").trimRight(fromString("数"))); + } } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 33bc79a92b9e7..96b6cf5a6a0dc 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -580,6 +580,9 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? 
#functionCall + | qualifiedName '(' trimOperator=(BOTH | LEADING | TRAILING) trimChar=namedExpression + FROM namedExpression ')' #functionCall + | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference @@ -748,6 +751,7 @@ nonReserved | UNBOUNDED | WHEN | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP | DIRECTORY + | BOTH | LEADING | TRAILING ; SELECT: 'SELECT'; @@ -861,6 +865,9 @@ COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; IGNORE: 'IGNORE'; +BOTH: 'BOTH'; +LEADING: 'LEADING'; +TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7ab45a6ee8737..3da24078fad3e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,6 +24,7 @@ import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -504,68 +505,267 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } /** - * A function that trim the spaces from both ends for the specified string. + * A function that trim the spaces or a character from both ends for the specified string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", - examples = """ + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
+ _FUNC_(BOTH trimChar FROM str) - Remove the leading and trailing trimChar from `str` + """, + extended = """ + Arguments: + str - a string expression + trimChar - the trim character + BOTH, FROM - these are keyword to specify for trim character from both side of the string + Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'S' FROM 'SSparkSQLS'); + parkSQL """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes { - def convert(v: UTF8String): UTF8String = v.trim() + require(children.size <= 2 && children.nonEmpty, + s"$prettyName requires at least one argument and no more than two.") + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) override def prettyName: String = "trim" + override def eval(input: InternalRow): Any = { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + if (inputs(0) != null) { + if (children.size == 1) { + return inputs(0).trim() + } else if (inputs(1) != null) { + if (inputs(0).numChars > 1) { + throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than " + + s"1 character.") + } else { + return inputs(1).trim(inputs(0)) + } + } + } + null + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + if (children.size == 2 && + (! children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { + throw new AnalysisException(s"The trimming parameter should be Literal " + + s"and only one character.") } + + val evals = children.map(_.genCode(ctx)) + val inputs = evals.map { eval => + s"${eval.isNull} ? 
null : ${eval.value}" + } + val getTrimFunction = if (children.size == 1) { + s"""UTF8String ${ev.value} = ${inputs(0)}.trim();""" + } else { + s"""UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});""".stripMargin + } + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + ${getTrimFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } + """) + } + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } } } /** - * A function that trim the spaces from left end for given string. + * A function that trim the spaces or a character from left end for given string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", - examples = """ + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
+ _FUNC_(LEADING trimChar FROM str) - Remove the leading trimChar from `str` + """, + extended = """ + Arguments: + str - a string expression + trimChar - the trim character + LEADING, FROM - these are keyword to specify for trim character from left side of the string + Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(LEADING 'S' FROM 'SSparkSQLS'); + parkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes { + + require (children.size <= 2 && children.nonEmpty, + "$prettyName requires at least one argument and no more than two.") - def convert(v: UTF8String): UTF8String = v.trimLeft() + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def eval(input: InternalRow): Any = { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + if (inputs(0) != null) { + if (children.size == 1) { + return inputs(0).trimLeft() + } else if (inputs(1) != null) { + if (inputs(0).numChars > 1) { + throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than" + + s" 1 character.") + } else { + return inputs(1).trimLeft(inputs(0)) + } + } + } + null + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (children.size == 2 && + (! 
children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { + throw new AnalysisException(s"The trimming parameter should be Literal " + + s"and only one character.") } + + val evals = children.map(_.genCode(ctx)) + val inputs = evals.map { eval => + s"${eval.isNull} ? null : ${eval.value}" + } + val getTrimLeftFunction = if (children.size == 1) { + s"""UTF8String ${ev.value} = ${inputs(0)}.trimLeft();""" + } else { + s"""UTF8String ${ev.value} = ${inputs(1)}.trimLeft(${inputs(0)});""" + } + + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + ${getTrimLeftFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } + """) + } + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } } } /** - * A function that trim the spaces from right end for given string. + * A function that trim the spaces or a character from right end for given string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", - examples = """ + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
+ _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimChar from `str` + """, + extended = """ + Arguments: + str - a string expression + trimChar - the trim character + TRAILING, FROM - these are keyword to specify for trim character from right side of the string + Examples: > SELECT _FUNC_(' SparkSQL '); - SparkSQL + SparkSQL + > SELECT _FUNC_(TRAILING 'S' FROM 'SSparkSQLS'); + SSparkSQL """) -case class StringTrimRight(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimRight(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes { + + require (children.size <= 2 && children.nonEmpty, + "$prettyName requires at least one argument and no more than two.") + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType - def convert(v: UTF8String): UTF8String = v.trimRight() + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) override def prettyName: String = "rtrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimRight()") + override def eval(input: InternalRow): Any = { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + if (inputs(0) != null) { + if (children.size == 1) { + return inputs(0).trimRight() + } else if (inputs(1) != null) { + if (inputs(0).numChars > 1) { + throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than" + + s" 1 character.") + } else { + return inputs(1).trimRight(inputs(0)) + } + } + } + null + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (children.size == 2 && + (! 
children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { + throw new AnalysisException(s"The trimming parameter should be Literal " + + s"and only one character.") } + + val evals = children.map(_.genCode(ctx)) + val inputs = evals.map { eval => + s"${eval.isNull} ? null : ${eval.value}" + } + val getTrimRightFunction = if (children.size == 1) { + s"""UTF8String ${ev.value} = ${inputs(0)}.trimRight();""" + } else { + s"""UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});""" + } + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + ${getTrimRightFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } + """) + } + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 891f61698f177..7c19571a3e22d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1181,6 +1181,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { // Create the function call. 
val name = ctx.qualifiedName.getText + val trimFuncName = Option(ctx.trimOperator).map { + o => visitTrimFuncName(ctx, o)} val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) @@ -1190,7 +1192,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } - val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct) + val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName, trimFuncName), + arguments, isDistinct) // Check if the function is evaluated in a windowed context. ctx.windowSpec match { @@ -1202,6 +1205,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create a name LTRIM for TRIM(Leading), RTRIM for TRIM(Trailing), TRIM for TRIM(BOTH) + */ + private def visitTrimFuncName(ctx: FunctionCallContext, opt: Token): String = { + if (ctx.qualifiedName.getText.toLowerCase != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) + } + opt.getType match { + case SqlBaseParser.BOTH => "trim" + case SqlBaseParser.LEADING => "ltrim" + case SqlBaseParser.TRAILING => "rtrim" + case _ => throw new ParseException(s"Function trim doesn't support " + + s"this ${opt.getType}.", ctx) + } + } + /** * Create a current timestamp/date expression. These are different from regular function because * they do not require the user to specify braces when calling them. @@ -1218,10 +1238,22 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create a function database (optional) and name pair. 
*/ - protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + protected def visitFunctionName( + ctx: QualifiedNameContext, + trimFuncN: Option[String] = None): FunctionIdentifier = { ctx.identifier().asScala.map(_.getText) match { - case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) - case Seq(fn) => FunctionIdentifier(fn, None) + case Seq(db, fn) => + if (fn.equalsIgnoreCase("trim") && trimFuncN.isDefined) { + FunctionIdentifier(trimFuncN.get, Option(db)) + } else { + FunctionIdentifier(fn, Option(db)) + } + case Seq(fn) => + if (fn.equalsIgnoreCase("trim") && trimFuncN.isDefined) { + FunctionIdentifier(trimFuncN.get, None) + } else { + FunctionIdentifier(fn, None) + } case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4f08031153ab0..57b4b73dff46c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("concat") { @@ -408,24 +407,67 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("TRIM/LTRIM/RTRIM") { val s = 'a.string.at(0) - checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) - checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) - - checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) - - 
checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) - checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq(Literal(" aa "))), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal("aa"))), "", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal(" aa"))), " ", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal("aa "))), " ", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal("aabbaaaa"))), "bb", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal("aabbaaaa "))), "bbaaaa ", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq(s)), "abdef", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("abdefa")) + checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("aaabdefaaaa")) + checkEvaluation(StringTrim(Seq("S", s)), "parkSQL", create_row("SSparkSQLS")) + + checkEvaluation(StringTrimLeft(Seq(Literal(" aa "))), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("a", Literal("aa"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("a", Literal("aa "))), " ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("a", Literal("aabbaaaa"))), "bbaaaa", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq(s)), "abdef ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("a", s)), "bdefa", create_row("abdefa")) + checkEvaluation(StringTrimLeft(Seq("a", s)), " aaabdefaaaa", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimLeft(Seq("S", s)), "parkSQLS", create_row("SSparkSQLS")) + checkEvaluation(StringTrimRight(Seq(Literal(" aa "))), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a", Literal("a"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a", Literal("aa"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a", Literal("aabbaaaa"))), "aabb", 
create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq(s)), " abdef", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a", s)), "abdef", create_row("abdefa")) + checkEvaluation(StringTrimRight(Seq("a", s)), " aaabdef", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimRight(Seq("S", s)), "SSparkSQL", create_row("SSparkSQLS")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. - checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimRight(Seq("花", Literal("a"))), "a", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a", Literal("花"))), "花", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("花", Literal("花"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq(s)), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimRight(Seq("花", s)), "花花世界", create_row("花花世界花花")) + checkEvaluation(StringTrimRight(Seq("花", s)), "", create_row("花花花花")) + checkEvaluation(StringTrimRight(Seq("花", s)), " 花花世界花花 ", create_row(" 花花世界花花 ")) + checkEvaluation(StringTrimRight(Seq("a", s)), "aa花花世界花花", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimRight(Seq("a", s)), "aa花花世界花花", create_row("aa花花世界花花")) + checkEvaluation(StringTrimLeft(Seq(s)), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(Seq("花", s)), "世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("花", s)), " 花花世界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("花", s)), "a花花世界花花 ", create_row("a花花世界花花 ")) + checkEvaluation(StringTrimLeft(Seq("a", s)), "花花世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(Seq("a", s)), "花花世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrim(Seq(s)), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(Seq("花", s)), "世界", 
create_row("花花世界花花")) + checkEvaluation(StringTrim(Seq("花", s)), " 花花世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(Seq("花", s)), " 花花世界花花 ", create_row(" 花花世界花花 ")) + checkEvaluation(StringTrim(Seq("a", s)), "花花世界花花", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(Seq("a", s)), "花花世界花花", create_row("aa花花世界花花")) + checkEvaluation(StringTrim(Seq("花", Literal("花"))), "", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("花", Literal("a"))), "a", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("a", Literal("花"))), "花", create_row(" abdef ")) // scalastyle:on - checkEvaluation(StringTrim(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) + checkEvaluation(StringTrim(Seq((Literal("a")), + (Literal.create(null, StringType)))), null) + checkEvaluation(StringTrimLeft(Seq((Literal("a")), + (Literal.create(null, StringType)))), null) + checkEvaluation(StringTrimRight(Seq((Literal("a")), + (Literal.create(null, StringType)))), null) } test("FORMAT") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47324ed9f2fb8..d5b267dbb7710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2331,7 +2331,15 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + def ltrim(e: Column): Column = withExpr {StringTrimLeft(Seq(e.expr))} + + /** + * Trim the specified character from left ends for the specified string column. + * @group string_funcs + * @since 2.0.0 + */ + def ltrim(trimChar: String, e: Column): Column = + withExpr { StringTrimLeft(Seq(Literal(trimChar), e.expr))} /** * Extract a specific group matched by a Java regex, from the specified string column. 
@@ -2408,7 +2416,15 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } + def rtrim(e: Column): Column = withExpr { StringTrimRight(Seq(e.expr)) } + + /** + * Trim the specified character from right ends for the specified string column. + * @group string_funcs + * @since 2.0.0 + */ + def rtrim(trimChar: String, e: Column): Column = + withExpr { StringTrimRight(Seq(Literal(trimChar), e.expr))} /** * Returns the soundex code for the specified expression. @@ -2475,7 +2491,15 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = withExpr { StringTrim(e.expr) } + def trim(e: Column): Column = withExpr { StringTrim(Seq(e.expr)) } + + /** + * Trim the specified character from both ends for the specified string column. + * @group string_funcs + * @since 2.0.0 + */ + def trim(trimChar: String, e: Column): Column = + withExpr { StringTrim(Seq(Literal(trimChar), e.expr))} /** * Converts a string column to upper case. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index a12efc835691b..2f5e3f03091d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -167,6 +167,10 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { df.select(ltrim($"a"), rtrim($"a"), trim($"a")), Row("example ", " example", "example")) + checkAnswer( + df.select(ltrim("e", $"b"), rtrim("e", $"b"), trim("e", $"b")), + Row("xample", "exampl", "xampl")) + checkAnswer( df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), Row("example ", " example", "example")) From 1918880143518bdd025af652b66bd7241068c77d Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Sat, 29 Apr 2017 20:46:13 -0700 Subject: [PATCH 02/21] adjust testcase --- .../spark/unsafe/types/UTF8StringSuite.java | 32 +++++++++++-------- .../spark/sql/StringFunctionsSuite.scala | 4 +-- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index b42d330834beb..eb6a21b430fd5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -222,21 +222,27 @@ public void substring() { @Test public void trims() { - assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(""))); - assertEquals(fromString("hello "), fromString(" hello ").trimLeft(fromString(""))); - assertEquals(fromString(" hello"), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + assertEquals(fromString(" hello"), fromString(" 
hello ").trimRight()); - assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(""))); - assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(""))); - assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(""))); + assertEquals(EMPTY_UTF8, fromString(" ").trim()); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); - assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim(fromString(""))); - assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(""))); - assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(""))); - - assertEquals(fromString("数据砖头"), fromString("数据砖头").trim(fromString(""))); - assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft(fromString(""))); - assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight(fromString(""))); + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + char[] charsLessThan0x20 = new char[10]; + Arrays.fill(charsLessThan0x20, (char)(' ' - 1)); + String stringStartingWithSpace = + new String(charsLessThan0x20) + "hello" + new String(charsLessThan0x20); + assertEquals(fromString(stringStartingWithSpace), fromString(stringStartingWithSpace).trim()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimLeft()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimRight()); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 2f5e3f03091d1..02df0ceaab644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -161,14 +161,14 @@ class 
StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") + val df = Seq((" example ", "", "example")).toDF("a", "b", "c") checkAnswer( df.select(ltrim($"a"), rtrim($"a"), trim($"a")), Row("example ", " example", "example")) checkAnswer( - df.select(ltrim("e", $"b"), rtrim("e", $"b"), trim("e", $"b")), + df.select(ltrim("e", $"c"), rtrim("e", $"c"), trim("e", $"c")), Row("xample", "exampl", "xampl")) checkAnswer( From bee1abc1b71c79803e4c54daf63aa01f79db6758 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 1 May 2017 00:13:53 -0700 Subject: [PATCH 03/21] address comments --- .../expressions/stringExpressions.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3da24078fad3e..2f18d0adbe61e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -511,18 +511,25 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
_FUNC_(BOTH trimChar FROM str) - Remove the leading and trailing trimChar from `str` + _FUNC_(LEADING trimChar FROM str) - Remove the leading trimChar from `str` + _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimChar from `str` """, extended = """ Arguments: str - a string expression trimChar - the trim character BOTH, FROM - these are keyword to specify for trim character from both side of the string - + LEADING, FROM - these are keyword to specify for trim character from left side of the string + TRAILING, FROM - these are keyword to specify for trim character from right side of the string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL > SELECT _FUNC_(BOTH 'S' FROM 'SSparkSQLS'); parkSQL + > SELECT _FUNC_(LEADING 'S' FROM 'SSparkSQLS'); + parkSQLS + > SELECT _FUNC_(TRAILING 'S' FROM 'SSparkSQLS'); + SSparkSQL """) case class StringTrim(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -598,18 +605,16 @@ case class StringTrim(children: Seq[Expression]) @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. - _FUNC_(LEADING trimChar FROM str) - Remove the leading trimChar from `str` + _FUNC_(trimChar, str) - Remove the leading trimChar from `str` """, extended = """ Arguments: str - a string expression trimChar - the trim character - LEADING, FROM - these are keyword to specify for trim character from left side of the string - Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL - > SELECT _FUNC_(LEADING 'S' FROM 'SSparkSQLS'); + > SELECT _FUNC_('S', 'SSparkSQLS'); parkSQLS """) case class StringTrimLeft(children: Seq[Expression]) @@ -687,18 +692,16 @@ case class StringTrimLeft(children: Seq[Expression]) @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
- _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimChar from `str` + _FUNC_(trimChar, str) - Remove the trailing trimChar from `str` """, extended = """ Arguments: str - a string expression trimChar - the trim character - TRAILING, FROM - these are keyword to specify for trim character from right side of the string - Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL - > SELECT _FUNC_(TRAILING 'S' FROM 'SSparkSQLS'); + > SELECT _FUNC_('S', 'SSparkSQLS'); SSparkSQL """) case class StringTrimRight(children: Seq[Expression]) From ff90849e28b9ffceabab334f24dd87721e45438d Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 1 May 2017 11:25:49 -0700 Subject: [PATCH 04/21] address comment --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2f18d0adbe61e..289326a7debcf 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -604,7 +604,7 @@ case class StringTrim(children: Seq[Expression]) */ @ExpressionDescription( usage = """ - _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(str) - Removes the leading space characters from `str`. _FUNC_(trimChar, str) - Remove the leading trimChar from `str` """, extended = """ @@ -691,7 +691,7 @@ case class StringTrimLeft(children: Seq[Expression]) */ @ExpressionDescription( usage = """ - _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(str) - Removes the trailing space characters from `str`. 
_FUNC_(trimChar, str) - Remove the trailing trimChar from `str` """, extended = """ From e434b25a069554d5ebfb7e9eab3a4a1ee6e047a0 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 3 May 2017 08:00:56 -0700 Subject: [PATCH 05/21] adjust LTRIM and RTRIM interface --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- .../scala/org/apache/spark/sql/StringFunctionsSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d5b267dbb7710..ff91a547839fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2338,7 +2338,7 @@ object functions { * @group string_funcs * @since 2.0.0 */ - def ltrim(trimChar: String, e: Column): Column = + def ltrim(e: Column, trimChar: String): Column = withExpr { StringTrimLeft(Seq(Literal(trimChar), e.expr))} /** @@ -2423,7 +2423,7 @@ object functions { * @group string_funcs * @since 2.0.0 */ - def rtrim(trimChar: String, e: Column): Column = + def rtrim(e: Column, trimChar: String): Column = withExpr { StringTrimRight(Seq(Literal(trimChar), e.expr))} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 02df0ceaab644..704557d9741ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -168,7 +168,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("example ", " example", "example")) checkAnswer( - df.select(ltrim("e", $"c"), rtrim("e", $"c"), trim("e", $"c")), + df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim("e", $"c")), Row("xample", "exampl", "xampl")) checkAnswer( From fecdd7b425dd3f37bea608a2e1f26f49c123c943 Mon Sep 17 
00:00:00 2001 From: Kevin Yu Date: Fri, 12 May 2017 23:47:06 -0700 Subject: [PATCH 06/21] adding trim charString support --- .../apache/spark/unsafe/types/UTF8String.java | 173 ++++++++++-------- .../spark/unsafe/types/UTF8StringSuite.java | 54 +++--- .../expressions/stringExpressions.scala | 87 ++++----- .../expressions/StringExpressionsSuite.scala | 54 +++--- .../org/apache/spark/sql/functions.scala | 16 +- .../spark/sql/StringFunctionsSuite.scala | 8 + 6 files changed, 195 insertions(+), 197 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 0509902ee1029..aeaaa3bc00dde 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -511,57 +511,63 @@ public UTF8String trim() { } /** - * Removes all specified trim character either from the beginning or the ending of a string - * @param trimChar the trim character + * Removes all specified trim character string either from the beginning or the ending of a string + * @param trimString the trim character string */ - public UTF8String trim(UTF8String trimChar) { - int numTrimBytes = trimChar.numBytes; - if (numTrimBytes == 0) { - return this; + public UTF8String trim(UTF8String trimString) { + // this method do the trimLeft first, then trimRight + int s = 0; // the searching byte position of the input string + int i = 0; // the first beginning byte position of a non-matching character + int e = 0; // the last byte position + int numChars = 0; // number of characters from the input string + int[] stringCharLen = new int[numBytes]; // array of character length for the input string + int[] stringCharPos = new int[numBytes]; // array of the first byte position for each character in the input string + int searchCharBytes; + + while (s < this.numBytes) { + UTF8String searchChar = 
copyUTF8String(s, s + numBytesForFirstByte(this.getByte(s)) - 1); + searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set + if (trimString.find(searchChar, 0) >= 0) { + i += searchCharBytes; + } else { + // no matching, exit the search + break; + } + s += searchCharBytes; } - int s = 0; - // e is the search index in the input string, starting with the trim character's bytes - // boundary, moving from right to left. - int e = this.numBytes - numTrimBytes; - // skip all the consecutive matching characters from left to right. - while (s < this.numBytes && s == this.find(trimChar, s)) { - s += numTrimBytes; - } - // skip all the consecutive matching character in the right side. - // if the trimming character has more bytes than the input string, 'e' points to the end - // of input string. - // The search index 'e' will be first positioned at the offset from the end of the input - // string by the number of bytes of the trimming character, if a matching is found, continue - // moving left until the string is exhausted or a non-matching character is hit. Every move - // is the number of bytes of the trimming character. When a non-matching character is hit, - // 'e' needs to be positioned back to the last byte of the non-matching character. - // example 1: - // trim character: 数, input string: 头, both character has 3 bytes. e starts - // 0, rfind could not find matching, index 'e' goes back to the last byte of - // no matching position. - // example 2: - // trim character: 数, input string a, 'a' has 1 byte, '数' has 3 bytes, e starts with -2, - // it should return with the input string - // example 3: - // trim character: 数, input string aaa数, 'aaa数' has 6 bytes, '数' has 3 bytes, e start with - // 3, find matching, move 3 bytes to position 0, didn't find matching, the index e goes back - // to the last byte of no matching position. 
- if (e < 0) { - e = this.numBytes - 1; + + if (i >= this.numBytes) { + // empty string + return UTF8String.EMPTY_UTF8; } else { - while (e >= 0 && e == this.rfind(trimChar, e)) { - e -= numTrimBytes; + //build the position and length array + s = 0; + while (s < numBytes) { + stringCharPos[numChars] = s; + stringCharLen[numChars]= numBytesForFirstByte(getByte(s)); + s += stringCharLen[numChars]; + numChars ++; } - if (e >= 0) { - e += numTrimBytes - 1; + + e = this.numBytes - 1; + while (numChars > 0) { + UTF8String searchChar = + copyUTF8String(stringCharPos[numChars-1], stringCharPos[numChars-1] + stringCharLen[numChars-1] - 1); + if (trimString.find(searchChar, 0) >= 0) { + e -= stringCharLen[numChars-1]; + } else { + break; + } + numChars --; } } - if (s > e) { + if (i > e) { // empty string return UTF8String.EMPTY_UTF8; } else { - return copyUTF8String(s, e); + return copyUTF8String(i, e); } } @@ -578,24 +584,34 @@ public UTF8String trimLeft() { } /** - * Removes all specified trim character from the beginning of a string - * @param trimChar the trim character + * Removes all specified trim characters from the beginning of a string + * @param trimString the trim character string */ - public UTF8String trimLeft(UTF8String trimChar) { - int numTrimBytes = trimChar.numBytes; - if (numTrimBytes == 0) { - return this; - } - int s = 0; - // skip all the consecutive matching character in the left side - while(s < this.numBytes && s == this.find(trimChar, s)) { - s += numTrimBytes; + public UTF8String trimLeft(UTF8String trimString) { + // this method will get one character from the input string, try to find the the matching character from + // the trimString set. 
+ int s = 0; // the searching byte position of the input string + int i = 0; // the first beginning byte position of a non-matching character + int searchCharBytes; + + while (s < this.numBytes) { + UTF8String searchChar = copyUTF8String(s, s + numBytesForFirstByte(this.getByte(s)) - 1); + searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set + if (trimString.find(searchChar, 0) >= 0) { + i += searchCharBytes; + } else { + // no matching, exit the search + break; + } + s += searchCharBytes; } - if (s == this.numBytes) { + + if (i >= this.numBytes) { // empty string return UTF8String.EMPTY_UTF8; } else { - return copyUTF8String(s, this.numBytes - 1); + return copyUTF8String(i, this.numBytes -1); } } @@ -614,33 +630,44 @@ public UTF8String trimRight() { /** * Removes all specified trim character from the ending of a string - * @param trimChar the trim character + * @param trimString the trim character string */ - public UTF8String trimRight(UTF8String trimChar) { - int numTrimBytes = trimChar.numBytes; - if (numTrimBytes == 0) { - return this; - } - int e = this.numBytes - numTrimBytes; - // skip all the consecutive matching character in the right side - // index 'e' points to first no matching byte position in the input string from right side. - // Index 'e' moves the number of bytes of the trimming character first. - if (e < 0) { - e = this.numBytes - 1; - } else { - while (e >= 0 && e == this.rfind(trimChar, e)) { - e -= numTrimBytes; - } - if (e >= 0) { - e += numTrimBytes - 1; + public UTF8String trimRight(UTF8String trimString) { + // this method will get one character from the input string from right to left, then try to find + // the matching character from the trimString set + + // index e points to first no matching byte position in the input string from right side, + // it moves the number of bytes of the trimming character first. 
+ int e; + int i = 0; + int numChars = 0; // number of characters from the input string + int[] stringCharLen = new int[numBytes]; // array of character length for the input string + int[] stringCharPos = new int[numBytes]; // array of the first byte position for each character in the input string + //build the position and length array + while (i < numBytes) { + stringCharPos[numChars] = i; + stringCharLen[numChars]= numBytesForFirstByte(getByte(i)); + i += stringCharLen[numChars]; + numChars ++; + } + + e = this.numBytes - 1; + while (numChars > 0) { + UTF8String searchChar = + copyUTF8String(stringCharPos[numChars-1], stringCharPos[numChars-1] + stringCharLen[numChars-1] - 1); + if (trimString.find(searchChar, 0) >= 0) { + e -= stringCharLen[numChars-1]; + } else { + break; } + numChars --; } if (e < 0) { // empty string return UTF8String.EMPTY_UTF8; } else { - return copyUTF8String(0, e); + return copyUTF8String(0,e); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index eb6a21b430fd5..8e78f02dd8f4e 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -728,57 +728,47 @@ public void testToLong() throws IOException { } @Test public void trimsChar() { - assertEquals(fromString(" hello "), fromString(" hello ").trim(fromString(""))); assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); - assertEquals(fromString("he"), fromString("ooheooo").trim(fromString("o"))); - assertEquals(fromString(""), fromString("ooooooo").trim(fromString("o"))); - assertEquals(fromString("b"), fromString("b").trim(fromString("o"))); - assertEquals(fromString(" "), fromString(" ooooooo").trim(fromString("o"))); + assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); + 
assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); + assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); + assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); - assertEquals(fromString("hello"), fromString("oohello").trimLeft(fromString("o"))); - assertEquals(fromString(" "), fromString("oooo ").trimLeft(fromString("o"))); + assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); + assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm"))); assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); - assertEquals(fromString("b"), fromString("b").trimRight(fromString("a"))); - assertEquals(fromString("ab"), fromString("ab").trimRight(fromString("a"))); - assertEquals(fromString(" hello"), fromString(" hello ").trimRight(fromString(" "))); - assertEquals(fromString("oohell"), fromString("oohelloooo").trimRight(fromString("o"))); - assertEquals(fromString(" oohello "), fromString(" oohello ").trimRight(fromString("o"))); - assertEquals(fromString(" oohell"), fromString(" oohelloo").trimRight(fromString("o"))); - assertEquals(fromString(""), fromString("ooooooo").trimRight(fromString("o"))); + assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); + assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); + 
assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); + assertEquals(fromString("oohell"), fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" "))); assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); - assertEquals(fromString("数"), fromString("数").trim(fromString("a"))); - assertEquals(fromString("a"), fromString("a").trim(fromString("数"))); - assertEquals(fromString(""), fromString("数数数数数").trim(fromString("数"))); - assertEquals(fromString("据砖头"), fromString("数数数据砖头数数").trim(fromString("数"))); + assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); + assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); + assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); + assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); - assertEquals(fromString(" 数数数据砖头"), fromString(" 数数数据砖头数数").trim(fromString("数"))); - assertEquals(fromString("a数数数据砖头数数a"), fromString("a数数数据砖头数数a").trim(fromString("数"))); assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a"))); assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); - assertEquals(fromString("据砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("数"))); - assertEquals(fromString(" 数数数据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数"))); - assertEquals(fromString("数数数据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a"))); + assertEquals(fromString("砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("据数"))); + assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 
"))); + assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); + assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); - assertEquals(fromString("数"), fromString("数").trimRight(fromString("a"))); - assertEquals(fromString("a"), fromString("a").trimRight(fromString("数"))); - assertEquals(fromString("头"), fromString("头").trimRight(fromString("数"))); - assertEquals(fromString("头"), fromString("头数数数").trimRight(fromString("数"))); - assertEquals(fromString("数数数据砖头"), fromString("数数数据砖头数数").trimRight(fromString("数"))); - assertEquals(fromString("数数数据砖头数数 "), fromString("数数数据砖头数数 ").trimRight(fromString("数"))); - assertEquals(fromString("aa数数数"), fromString("aa数数数aaa").trimRight(fromString("a"))); - assertEquals(fromString("数数"), fromString("数数").trimRight(fromString("a"))); - assertEquals(fromString("数数aa"), fromString("数数aa").trimRight(fromString("数"))); + assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数"))); + assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab"))); + assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); + assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 289326a7debcf..c00aeb0260e41 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -505,31 +505,31 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } /** - * A function that trim the spaces or a character from both ends for the 
specified string. + * A function that trim the spaces or a trim string from both ends for the specified string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. - _FUNC_(BOTH trimChar FROM str) - Remove the leading and trailing trimChar from `str` - _FUNC_(LEADING trimChar FROM str) - Remove the leading trimChar from `str` - _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimChar from `str` + _FUNC_(BOTH trimString FROM str) - Remove the leading and trailing trimString from `str` + _FUNC_(LEADING trimChar FROM str) - Remove the leading trimString from `str` + _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimString from `str` """, extended = """ Arguments: str - a string expression - trimChar - the trim character - BOTH, FROM - these are keyword to specify for trim character from both side of the string - LEADING, FROM - these are keyword to specify for trim character from left side of the string - TRAILING, FROM - these are keyword to specify for trim character from right side of the string + trimString - the trim string + BOTH, FROM - these are keyword to specify for trim string from both side of the string + LEADING, FROM - these are keyword to specify for trim string from left side of the string + TRAILING, FROM - these are keyword to specify for trim string from right side of the string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL - > SELECT _FUNC_(BOTH 'S' FROM 'SSparkSQLS'); - parkSQL - > SELECT _FUNC_(LEADING 'S' FROM 'SSparkSQLS'); - parkSQLS - > SELECT _FUNC_(TRAILING 'S' FROM 'SSparkSQLS'); - SSparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ + > SELECT _FUNC_(LEADING 'paS' FROM 'SSparkSQLS'); + rkSQLS + > SELECT _FUNC_(TRAILING 'SLQ' FROM 'SSparkSQLS'); + SSparkS """) case class StringTrim(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -551,22 +551,15 @@ case class StringTrim(children: Seq[Expression]) if 
(children.size == 1) { return inputs(0).trim() } else if (inputs(1) != null) { - if (inputs(0).numChars > 1) { - throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than " + - s"1 character.") - } else { - return inputs(1).trim(inputs(0)) - } + return inputs(1).trim(inputs(0)) } } null } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && - (! children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { - throw new AnalysisException(s"The trimming parameter should be Literal " + - s"and only one character.") } + if (children.size == 2 && ! children(0).isInstanceOf[Literal]) { + throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => @@ -600,22 +593,22 @@ case class StringTrim(children: Seq[Expression]) } /** - * A function that trim the spaces or a character from left end for given string. + * A function that trim the spaces or a trim string from left end for given string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading space characters from `str`. 
- _FUNC_(trimChar, str) - Remove the leading trimChar from `str` + _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string from the `str` """, extended = """ Arguments: str - a string expression - trimChar - the trim character + trimStr - the trim string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL - > SELECT _FUNC_('S', 'SSparkSQLS'); - parkSQLS + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) case class StringTrimLeft(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -637,22 +630,15 @@ case class StringTrimLeft(children: Seq[Expression]) if (children.size == 1) { return inputs(0).trimLeft() } else if (inputs(1) != null) { - if (inputs(0).numChars > 1) { - throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than" + - s" 1 character.") - } else { - return inputs(1).trimLeft(inputs(0)) - } + return inputs(1).trimLeft(inputs(0)) } } null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && - (! children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { - throw new AnalysisException(s"The trimming parameter should be Literal " + - s"and only one character.") } + if (children.size == 2 && ! children(0).isInstanceOf[Literal]) { + throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => @@ -687,22 +673,22 @@ case class StringTrimLeft(children: Seq[Expression]) } /** - * A function that trim the spaces or a character from right end for given string. + * A function that trim the spaces or a trim string from right end for given string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the trailing space characters from `str`. 
- _FUNC_(trimChar, str) - Remove the trailing trimChar from `str` + _FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` """, extended = """ Arguments: str - a string expression - trimChar - the trim character + trimStr - the trim string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL - > SELECT _FUNC_('S', 'SSparkSQLS'); - SSparkSQL + > SELECT _FUNC_('LQSa', 'SSparkSQLS'); + SSpark """) case class StringTrimRight(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -724,22 +710,15 @@ case class StringTrimRight(children: Seq[Expression]) if (children.size == 1) { return inputs(0).trimRight() } else if (inputs(1) != null) { - if (inputs(0).numChars > 1) { - throw new AnalysisException(s"Trim character '${inputs(0)}' can not be greater than" + - s" 1 character.") - } else { - return inputs(1).trimRight(inputs(0)) + return inputs(1).trimRight(inputs(0)) } } - } null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && - (! children(0).isInstanceOf[Literal] || children(0).toString.length > 1)) { - throw new AnalysisException(s"The trimming parameter should be Literal " + - s"and only one character.") } + if (children.size == 2 && ! 
children(0).isInstanceOf[Literal]) { + throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 57b4b73dff46c..a1eba96764d0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -409,58 +409,52 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s = 'a.string.at(0) checkEvaluation(StringTrim(Seq(Literal(" aa "))), "aa", create_row(" abdef ")) checkEvaluation(StringTrim(Seq("a", Literal("aa"))), "", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal(" aa"))), " ", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal("aa "))), " ", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal("aabbaaaa"))), "bb", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal("aabbaaaa "))), "bbaaaa ", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("ab cd", Literal(" aabbtrimccc"))), "trim", create_row("bdef")) + checkEvaluation(StringTrim(Seq("a.,@<>", Literal("a@>.,>"))), " ", create_row(" abdef ")) checkEvaluation(StringTrim(Seq(s)), "abdef", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("abdefa")) + checkEvaluation(StringTrim(Seq("abd", s)), "ef", create_row("abdefa")) checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("aaabdefaaaa")) - checkEvaluation(StringTrim(Seq("S", s)), "parkSQL", create_row("SSparkSQLS")) + checkEvaluation(StringTrim(Seq("SLSQ", s)), "park", create_row("SSparkSQLS")) checkEvaluation(StringTrimLeft(Seq(Literal(" aa "))), "aa ", 
create_row(" abdef ")) checkEvaluation(StringTrimLeft(Seq("a", Literal("aa"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("a", Literal("aa "))), " ", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("a", Literal("aabbaaaa"))), "bbaaaa", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("a ", Literal("aa "))), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Seq("ab", Literal("aabbcaaaa"))), "caaaa", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Seq(s)), "abdef ", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Seq("a", s)), "bdefa", create_row("abdefa")) - checkEvaluation(StringTrimLeft(Seq("a", s)), " aaabdefaaaa", create_row(" aaabdefaaaa")) - checkEvaluation(StringTrimLeft(Seq("S", s)), "parkSQLS", create_row("SSparkSQLS")) + checkEvaluation(StringTrimLeft(Seq("a ", s)), "bdefaaaa", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimLeft(Seq("Spk", s)), "arkSQLS", create_row("SSparkSQLS")) checkEvaluation(StringTrimRight(Seq(Literal(" aa "))), " aa", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq("a", Literal("a"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a", Literal("aa"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a", Literal("aabbaaaa"))), "aabb", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("ab", Literal("ab"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("a %", Literal("aabbaaaa %"))), "aabb", create_row("def")) checkEvaluation(StringTrimRight(Seq(s)), " abdef", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq("a", s)), "abdef", create_row("abdefa")) - checkEvaluation(StringTrimRight(Seq("a", s)), " aaabdef", create_row(" aaabdefaaaa")) - checkEvaluation(StringTrimRight(Seq("S", s)), "SSparkSQL", create_row("SSparkSQLS")) + checkEvaluation(StringTrimRight(Seq("abf de", s)), "", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimRight(Seq("S*&", s)), "SSparkSQL", 
create_row("SSparkSQLS*")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. checkEvaluation(StringTrimRight(Seq("花", Literal("a"))), "a", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq("a", Literal("花"))), "花", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("花", Literal("花"))), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Seq("界花世", Literal("花花世界"))), "", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq(s)), " 花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrimRight(Seq("花", s)), "花花世界", create_row("花花世界花花")) + checkEvaluation(StringTrimRight(Seq("花a#", s)), "花花世界", create_row("花花世界花花###aa花")) checkEvaluation(StringTrimRight(Seq("花", s)), "", create_row("花花花花")) - checkEvaluation(StringTrimRight(Seq("花", s)), " 花花世界花花 ", create_row(" 花花世界花花 ")) - checkEvaluation(StringTrimRight(Seq("a", s)), "aa花花世界花花", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrimRight(Seq("a", s)), "aa花花世界花花", create_row("aa花花世界花花")) + checkEvaluation(StringTrimRight(Seq("花 界b@", s)), " 花花世", create_row(" 花花世 b界@花花 ")) checkEvaluation(StringTrimLeft(Seq(s)), "花花世界 ", create_row(" 花花世界 ")) checkEvaluation(StringTrimLeft(Seq("花", s)), "世界花花", create_row("花花世界花花")) - checkEvaluation(StringTrimLeft(Seq("花", s)), " 花花世界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("花 世", s)), "界花花", create_row(" 花花世界花花")) checkEvaluation(StringTrimLeft(Seq("花", s)), "a花花世界花花 ", create_row("a花花世界花花 ")) - checkEvaluation(StringTrimLeft(Seq("a", s)), "花花世界花花aa", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrimLeft(Seq("a", s)), "花花世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("a花界", s)), "世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(Seq("a世界", s)), "花花世界花花", create_row("花花世界花花")) checkEvaluation(StringTrim(Seq(s)), "花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(Seq("花", s)), "世界", create_row("花花世界花花")) - 
checkEvaluation(StringTrim(Seq("花", s)), " 花花世界", create_row(" 花花世界花花")) - checkEvaluation(StringTrim(Seq("花", s)), " 花花世界花花 ", create_row(" 花花世界花花 ")) - checkEvaluation(StringTrim(Seq("a", s)), "花花世界花花", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrim(Seq("a", s)), "花花世界花花", create_row("aa花花世界花花")) - checkEvaluation(StringTrim(Seq("花", Literal("花"))), "", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("花", Literal("a"))), "a", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal("花"))), "花", create_row(" abdef ")) + checkEvaluation(StringTrim(Seq("花世界", s)), "", create_row("花花世界花花")) + checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花 花 世界 花 花 ")) + checkEvaluation(StringTrim(Seq("a花世", s)), "界", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(Seq("a@#( )", s)), "花花世界花花", create_row("aa()花花世界花花@ #")) + checkEvaluation(StringTrim(Seq("花 ", Literal("花trim"))), "trim", create_row(" abdef ")) // scalastyle:on checkEvaluation(StringTrim(Seq((Literal("a")), (Literal.create(null, StringType)))), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91a547839fb..056b6dcbd4e81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2334,12 +2334,12 @@ object functions { def ltrim(e: Column): Column = withExpr {StringTrimLeft(Seq(e.expr))} /** - * Trim the specified character from left ends for the specified string column. + * Trim the specified character string from left ends for the specified string column. 
* @group string_funcs * @since 2.0.0 */ - def ltrim(e: Column, trimChar: String): Column = - withExpr { StringTrimLeft(Seq(Literal(trimChar), e.expr))} + def ltrim(e: Column, trimString: String): Column = + withExpr { StringTrimLeft(Seq(Literal(trimString), e.expr))} /** * Extract a specific group matched by a Java regex, from the specified string column. @@ -2419,12 +2419,12 @@ object functions { def rtrim(e: Column): Column = withExpr { StringTrimRight(Seq(e.expr)) } /** - * Trim the specified character from right ends for the specified string column. + * Trim the specified character string from right ends for the specified string column. * @group string_funcs * @since 2.0.0 */ - def rtrim(e: Column, trimChar: String): Column = - withExpr { StringTrimRight(Seq(Literal(trimChar), e.expr))} + def rtrim(e: Column, trimString: String): Column = + withExpr { StringTrimRight(Seq(Literal(trimString), e.expr))} /** * Returns the soundex code for the specified expression. @@ -2498,8 +2498,8 @@ object functions { * @group string_funcs * @since 2.0.0 */ - def trim(trimChar: String, e: Column): Column = - withExpr { StringTrim(Seq(Literal(trimChar), e.expr))} + def trim(trimString: String, e: Column): Column = + withExpr { StringTrim(Seq(Literal(trimString), e.expr))} /** * Converts a string column to upper case. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 704557d9741ad..21f209c89b870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -171,6 +171,14 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim("e", $"c")), Row("xample", "exampl", "xampl")) + checkAnswer( + df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim("elxp", $"c")), + Row("ample", "exa", "am")) + + checkAnswer( + df.select(trim("xyz", $"c")), + Row("example")) + checkAnswer( df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), Row("example ", " example", "example")) From d73935504ae24c9c9d9d971bc6d31d4de19b6eb4 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 16 May 2017 13:02:08 -0700 Subject: [PATCH 07/21] address comments --- .../expressions/stringExpressions.scala | 65 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 14 ++-- .../org/apache/spark/sql/functions.scala | 10 +-- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c00aeb0260e41..c2d74bba50f5a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -505,7 +505,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } /** - * A function that trim the spaces or a trim string from both ends for the specified string. + * A function that trims leading or trailing characters (or both) from the specified string. 
*/ @ExpressionDescription( usage = """ @@ -518,9 +518,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi Arguments: str - a string expression trimString - the trim string - BOTH, FROM - these are keyword to specify for trim string from both side of the string - LEADING, FROM - these are keyword to specify for trim string from left side of the string - TRAILING, FROM - these are keyword to specify for trim string from right side of the string + BOTH, FROM - these are keywords to trim the trim string from both ends of the string + LEADING, FROM - these are keywords to trim the trim string from the left end of the string + TRAILING, FROM - these are keywords to trim the trim string from the right end of the string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL @@ -558,7 +558,7 @@ case class StringTrim(children: Seq[Expression]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && ! children(0).isInstanceOf[Literal]) { + if (children.size == 2 && !children(0).isInstanceOf[Literal]) { throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) @@ -566,17 +566,16 @@ case class StringTrim(children: Seq[Expression]) s"${eval.isNull} ? 
null : ${eval.value}" } val getTrimFunction = if (children.size == 1) { - s"""UTF8String ${ev.value} = ${inputs(0)}.trim();""" + s"UTF8String ${ev.value} = ${inputs(0)}.trim();" } else { - s"""UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});""".stripMargin - } - ev.copy(evals.map(_.code).mkString("\n") + - s""" - boolean ${ev.isNull} = false; - ${getTrimFunction}; - if (${ev.value} == null) { - ${ev.isNull} = true; + s"UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});".stripMargin } + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + ${getTrimFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } """) } @@ -637,7 +636,7 @@ case class StringTrimLeft(children: Seq[Expression]) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && ! children(0).isInstanceOf[Literal]) { + if (children.size == 2 && !children(0).isInstanceOf[Literal]) { throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) @@ -645,18 +644,17 @@ case class StringTrimLeft(children: Seq[Expression]) s"${eval.isNull} ? 
null : ${eval.value}" } val getTrimLeftFunction = if (children.size == 1) { - s"""UTF8String ${ev.value} = ${inputs(0)}.trimLeft();""" + s"UTF8String ${ev.value} = ${inputs(0)}.trimLeft();" } else { - s"""UTF8String ${ev.value} = ${inputs(1)}.trimLeft(${inputs(0)});""" + s"UTF8String ${ev.value} = ${inputs(1)}.trimLeft(${inputs(0)});" } - ev.copy(evals.map(_.code).mkString("\n") + - s""" - boolean ${ev.isNull} = false; - ${getTrimLeftFunction}; - if (${ev.value} == null) { - ${ev.isNull} = true; - } + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + ${getTrimLeftFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } """) } @@ -717,7 +715,7 @@ case class StringTrimRight(children: Seq[Expression]) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && ! children(0).isInstanceOf[Literal]) { + if (children.size == 2 && !children(0).isInstanceOf[Literal]) { throw new AnalysisException(s"The trimming parameter should be Literal.")} val evals = children.map(_.genCode(ctx)) @@ -725,17 +723,16 @@ case class StringTrimRight(children: Seq[Expression]) s"${eval.isNull} ? 
null : ${eval.value}" } val getTrimRightFunction = if (children.size == 1) { - s"""UTF8String ${ev.value} = ${inputs(0)}.trimRight();""" + s"UTF8String ${ev.value} = ${inputs(0)}.trimRight();" } else { - s"""UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});""" - } - ev.copy(evals.map(_.code).mkString("\n") + - s""" - boolean ${ev.isNull} = false; - ${getTrimRightFunction}; - if (${ev.value} == null) { - ${ev.isNull} = true; + s"UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});" } + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + ${getTrimRightFunction}; + if (${ev.value} == null) { + ${ev.isNull} = true; + } """) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7c19571a3e22d..2c2cec2a83913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1217,8 +1217,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.BOTH => "trim" case SqlBaseParser.LEADING => "ltrim" case SqlBaseParser.TRAILING => "rtrim" - case _ => throw new ParseException(s"Function trim doesn't support " + - s"this ${opt.getType}.", ctx) + case _ => throw new ParseException(s"Function trim doesn't support with" + + s"type ${opt.getType}. 
Please use BOTH, LEADING or Trailing as trim type", ctx) } } @@ -1240,17 +1240,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ protected def visitFunctionName( ctx: QualifiedNameContext, - trimFuncN: Option[String] = None): FunctionIdentifier = { + trimFuncName: Option[String] = None): FunctionIdentifier = { ctx.identifier().asScala.map(_.getText) match { case Seq(db, fn) => - if (fn.equalsIgnoreCase("trim") && trimFuncN.isDefined) { - FunctionIdentifier(trimFuncN.get, Option(db)) + if (trimFuncName.isDefined) { + FunctionIdentifier(trimFuncName.get, Option(db)) } else { FunctionIdentifier(fn, Option(db)) } case Seq(fn) => - if (fn.equalsIgnoreCase("trim") && trimFuncN.isDefined) { - FunctionIdentifier(trimFuncN.get, None) + if (trimFuncName.isDefined) { + FunctionIdentifier(trimFuncName.get, None) } else { FunctionIdentifier(fn, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 056b6dcbd4e81..aa8f7e6a90ea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2334,9 +2334,9 @@ object functions { def ltrim(e: Column): Column = withExpr {StringTrimLeft(Seq(e.expr))} /** - * Trim the specified character string from left ends for the specified string column. + * Trim the specified character string from left end for the specified string column. * @group string_funcs - * @since 2.0.0 + * @since 2.2.0 */ def ltrim(e: Column, trimString: String): Column = withExpr { StringTrimLeft(Seq(Literal(trimString), e.expr))} @@ -2419,9 +2419,9 @@ object functions { def rtrim(e: Column): Column = withExpr { StringTrimRight(Seq(e.expr)) } /** - * Trim the specified character string from right ends for the specified string column. + * Trim the specified character string from right end for the specified string column. 
* @group string_funcs - * @since 2.0.0 + * @since 2.2.0 */ def rtrim(e: Column, trimString: String): Column = withExpr { StringTrimRight(Seq(Literal(trimString), e.expr))} @@ -2496,7 +2496,7 @@ object functions { /** * Trim the specified character from both ends for the specified string column. * @group string_funcs - * @since 2.0.0 + * @since 2.2.0 */ def trim(trimString: String, e: Column): Column = withExpr { StringTrim(Seq(Literal(trimString), e.expr))} From 15a63d18d80ec041deb9cc8c38fe1eab9664de24 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 17 May 2017 23:00:27 -0700 Subject: [PATCH 08/21] address comments --- .../apache/spark/unsafe/types/UTF8String.java | 18 +++++---- .../expressions/stringExpressions.scala | 37 +++++++++++++++---- .../sql/catalyst/parser/AstBuilder.scala | 3 +- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index aeaaa3bc00dde..16f84cb2d37fb 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -511,11 +511,13 @@ public UTF8String trim() { } /** - * Removes all specified trim character string either from the beginning or the ending of a string + * Removes the given trim string from both ends of a string * @param trimString the trim character string */ public UTF8String trim(UTF8String trimString) { - // this method do the trimLeft first, then trimRight + // This method searches for each character in the source string, removes the character if it is found + // in the trim string, stops at the first not found. It starts from left end, then right end. + // It returns a new string in which both ends trim characters have been removed. 
int s = 0; // the searching byte position of the input string int i = 0; // the first beginning byte position of a non-matching character int e = 0; // the last byte position @@ -584,12 +586,12 @@ public UTF8String trimLeft() { } /** - * Removes all specified trim characters from the beginning of a string + * Removes the given trim string from the beginning of a string * @param trimString the trim character string */ public UTF8String trimLeft(UTF8String trimString) { - // this method will get one character from the input string, try to find the the matching character from - // the trimString set. + // this method searches each character in the source string starting from the left end, removes the character if it + // is in the trim string, stops at the first character which is not in the trim string, returns the new string. int s = 0; // the searching byte position of the input string int i = 0; // the first beginning byte position of a non-matching character int searchCharBytes; @@ -629,12 +631,12 @@ public UTF8String trimRight() { } /** - * Removes all specified trim character from the ending of a string + * Removes the given trim string from the ending of a string * @param trimString the trim character string */ public UTF8String trimRight(UTF8String trimString) { - // this method will get one character from the input string from right to left, then try to find - // the matching character from the trimString set + // this method searches each character in the source string starting from the right end, removes the character if it + // is in the trim string, stops at the first character which is not in the trim string, returns the new string. // index e points to first no matching byte position in the input string from right side, // it moves the number of bytes of the trimming character first. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c2d74bba50f5a..c611b5ab99d05 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -505,14 +505,22 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } /** - * A function that trims leading or trailing characters (or both) from the specified string. - */ + * A function that takes a character string, removes the leading and/or trailing characters matching with the characters + * in the trim string, returns the new string. If LEADING/TRAILING/BOTH and trimStr keywords are not specified, it + * defaults to remove space character from both ends. + * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function + * searches for each character in the source string, removes the characters from the source string until it + * encounters the first non-match character. + * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. + * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. + * BOTH: removes any characters from both ends of the source string that matches characters in the trim string. + */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
- _FUNC_(BOTH trimString FROM str) - Remove the leading and trailing trimString from `str` - _FUNC_(LEADING trimChar FROM str) - Remove the leading trimString from `str` - _FUNC_(TRAILING trimChar FROM str) - Remove the trailing trimString from `str` + _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimStr characters from `str` + _FUNC_(LEADING trimStr FROM str) - Remove the leading trimStr characters from `str` + _FUNC_(TRAILING trimStr FROM str) - Remove the trailing trimStr characters from `str` """, extended = """ Arguments: @@ -545,6 +553,9 @@ case class StringTrim(children: Seq[Expression]) override def prettyName: String = "trim" + // trim function can take one or two arguments. + // For one argument(children size is 1), it is the trim space function. + // For two arguments(children size is 2), it is the trim function with one of these options: BOTH/LEADING/TRAILING. override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -592,7 +603,10 @@ case class StringTrim(children: Seq[Expression]) } /** - * A function that trim the spaces or a trim string from left end for given string. + * A function that trims the characters from left end for a given string, if the trimStr is not specified, it defaults + * to trim the spaces from the left end of the source string. + * trimStr: the function removes any characters from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. */ @ExpressionDescription( usage = """ @@ -623,6 +637,9 @@ case class StringTrimLeft(children: Seq[Expression]) override def prettyName: String = "ltrim" + // ltrim function can take one or two arguments. + // For one argument(children size is 1), it is the ltrim space function. + // For two arguments(children size is 2), it is the trim function with option LEADING. 
override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -671,7 +688,10 @@ case class StringTrimLeft(children: Seq[Expression]) } /** - * A function that trim the spaces or a trim string from right end for given string. + * A function that trims the characters from right end for a given string, if the trimStr is not specified, it defaults + * to trim the spaces from the right end of the source string. + * trimStr: the function removes any characters from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. */ @ExpressionDescription( usage = """ @@ -702,6 +722,9 @@ case class StringTrimRight(children: Seq[Expression]) override def prettyName: String = "rtrim" + // rtrim function can take one or two arguments. + // For one argument(children size is 1), it is the rtrim space function. + // For two arguments(children size is 2), it is the trim function with option TRAILING. override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2c2cec2a83913..b2eb1a3d9b778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1182,7 +1182,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Create the function call. 
val name = ctx.qualifiedName.getText val trimFuncName = Option(ctx.trimOperator).map { - o => visitTrimFuncName(ctx, o)} + o => visitTrimFuncName(ctx, o) + } val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) From 10dab6f396832d2ce1116261fabb49a1a7ecab41 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 17 May 2017 23:10:59 -0700 Subject: [PATCH 09/21] split trim test in UTF8StringSuite --- .../spark/unsafe/types/UTF8StringSuite.java | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 8e78f02dd8f4e..11338fed67784 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -727,12 +727,27 @@ public void testToLong() throws IOException { } } @Test - public void trimsChar() { + public void trim() { assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); + + + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); + + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); + assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); + assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); + assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); + 
assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); + } + + @Test + public void trimLeft() { assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); @@ -741,23 +756,9 @@ public void trimsChar() { assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm"))); - assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); - assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); - assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); - assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); - assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); - assertEquals(fromString("oohell"), fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); - assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" "))); - assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); - assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); - assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); - assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); - assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); - assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); - assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); assertEquals(fromString("数"), 
fromString("数").trimLeft(fromString("a"))); assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); @@ -765,6 +766,18 @@ public void trimsChar() { assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 "))); assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); + } + @Test + public void trimRight() { + assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); + assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); + assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); + assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); + assertEquals(fromString("oohell"), fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数"))); assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab"))); From 04f0c10dffdadb8f492319a4c8cd8f70997954ac Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 23 May 2017 20:43:50 -0700 Subject: [PATCH 10/21] adjust codes based on comments --- .../apache/spark/unsafe/types/UTF8String.java | 120 +++++------------- .../spark/unsafe/types/UTF8StringSuite.java | 2 - .../expressions/StringExpressionsSuite.scala | 51 +++++--- 3 files changed, 66 insertions(+), 107 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 16f84cb2d37fb..a09315931e194 100644 --- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -512,65 +512,13 @@ public UTF8String trim() { /** * Removes the given trim string from both ends of a string + * This method searches for each character in the source string, removes the character if it is found + * in the trim string, stops at the first not found. It calls the trimLeft first, then trimRight. + * It returns a new string in which both ends trim characters have been removed. * @param trimString the trim character string */ public UTF8String trim(UTF8String trimString) { - // This method searches for each character in the source string, removes the character if it is found - // in the trim string, stops at the first not found. It starts from left end, then right end. - // It returns a new string in which both ends trim characters have been removed. - int s = 0; // the searching byte position of the input string - int i = 0; // the first beginning byte position of a non-matching character - int e = 0; // the last byte position - int numChars = 0; // number of characters from the input string - int[] stringCharLen = new int[numBytes]; // array of character length for the input string - int[] stringCharPos = new int[numBytes]; // array of the first byte position for each character in the input string - int searchCharBytes; - - while (s < this.numBytes) { - UTF8String searchChar = copyUTF8String(s, s + numBytesForFirstByte(this.getByte(s)) - 1); - searchCharBytes = searchChar.numBytes; - // try to find the matching for the searchChar in the trimString set - if (trimString.find(searchChar, 0) >= 0) { - i += searchCharBytes; - } else { - // no matching, exit the search - break; - } - s += searchCharBytes; - } - - if (i >= this.numBytes) { - // empty string - return UTF8String.EMPTY_UTF8; - } else { - //build the position and length array - s = 0; - while (s < numBytes) { - stringCharPos[numChars] = s; 
- stringCharLen[numChars]= numBytesForFirstByte(getByte(s)); - s += stringCharLen[numChars]; - numChars ++; - } - - e = this.numBytes - 1; - while (numChars > 0) { - UTF8String searchChar = - copyUTF8String(stringCharPos[numChars-1], stringCharPos[numChars-1] + stringCharLen[numChars-1] - 1); - if (trimString.find(searchChar, 0) >= 0) { - e -= stringCharLen[numChars-1]; - } else { - break; - } - numChars --; - } - } - - if (i > e) { - // empty string - return UTF8String.EMPTY_UTF8; - } else { - return copyUTF8String(i, e); - } + return trimLeft(trimString).trimRight(trimString); } public UTF8String trimLeft() { @@ -587,33 +535,32 @@ public UTF8String trimLeft() { /** * Removes the given trim string from the beginning of a string + * This method searches each character in the source string starting from the left end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string */ public UTF8String trimLeft(UTF8String trimString) { - // this method searches each character in the source string starting from the left end, removes the character if it - // is in the trim string, stops at the first character which is not in the trim string, returns the new string. 
- int s = 0; // the searching byte position of the input string - int i = 0; // the first beginning byte position of a non-matching character - int searchCharBytes; - - while (s < this.numBytes) { - UTF8String searchChar = copyUTF8String(s, s + numBytesForFirstByte(this.getByte(s)) - 1); - searchCharBytes = searchChar.numBytes; + int srchIdx = 0; // the searching byte position of the input string + int trimIdx = 0; // the first beginning byte position of a non-matching character + + while (srchIdx < numBytes) { + UTF8String searchChar = copyUTF8String(srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; // try to find the matching for the searchChar in the trimString set if (trimString.find(searchChar, 0) >= 0) { - i += searchCharBytes; + trimIdx += searchCharBytes; } else { // no matching, exit the search break; } - s += searchCharBytes; + srchIdx += searchCharBytes; } - if (i >= this.numBytes) { + if (trimIdx >= numBytes) { // empty string - return UTF8String.EMPTY_UTF8; + return EMPTY_UTF8; } else { - return copyUTF8String(i, this.numBytes -1); + return copyUTF8String(trimIdx, numBytes -1); } } @@ -632,44 +579,43 @@ public UTF8String trimRight() { /** * Removes the given trim string from the ending of a string + * This method searches each character in the source string starting from the right end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string */ public UTF8String trimRight(UTF8String trimString) { - // this method searches each character in the source string starting from the right end, removes the character if it - // is in the trim string, stops at the first character which is not in the trim string, returns the new string. 
- - // index e points to first no matching byte position in the input string from right side, - // it moves the number of bytes of the trimming character first. - int e; - int i = 0; + // index trimEnd points to first no matching byte position in the input string from right side, + // it moves the number of bytes of the trimming character. + int trimEnd; + int trimIdx = 0; int numChars = 0; // number of characters from the input string int[] stringCharLen = new int[numBytes]; // array of character length for the input string int[] stringCharPos = new int[numBytes]; // array of the first byte position for each character in the input string //build the position and length array - while (i < numBytes) { - stringCharPos[numChars] = i; - stringCharLen[numChars]= numBytesForFirstByte(getByte(i)); - i += stringCharLen[numChars]; + while (trimIdx < numBytes) { + stringCharPos[numChars] = trimIdx; + stringCharLen[numChars]= numBytesForFirstByte(getByte(trimIdx)); + trimIdx += stringCharLen[numChars]; numChars ++; } - e = this.numBytes - 1; + trimEnd = numBytes - 1; while (numChars > 0) { UTF8String searchChar = - copyUTF8String(stringCharPos[numChars-1], stringCharPos[numChars-1] + stringCharLen[numChars-1] - 1); + copyUTF8String(stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); if (trimString.find(searchChar, 0) >= 0) { - e -= stringCharLen[numChars-1]; + trimEnd -= stringCharLen[numChars - 1]; } else { break; } numChars --; } - if (e < 0) { + if (trimEnd < 0) { // empty string - return UTF8String.EMPTY_UTF8; + return EMPTY_UTF8; } else { - return copyUTF8String(0,e); + return copyUTF8String(0, trimEnd); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 11338fed67784..ea9e5a114aa32 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -734,10 +734,8 @@ public void trim() { assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); - assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); - assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index a1eba96764d0d..908c5ba16340b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -405,7 +405,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } - test("TRIM/LTRIM/RTRIM") { + test("TRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Seq(Literal(" aa "))), "aa", create_row(" abdef ")) checkEvaluation(StringTrim(Seq("a", Literal("aa"))), "", create_row(" abdef ")) @@ -416,6 +416,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("aaabdefaaaa")) checkEvaluation(StringTrim(Seq("SLSQ", s)), "park", create_row("SSparkSQLS")) + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
+ checkEvaluation(StringTrim(Seq(s)), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(Seq("花世界", s)), "", create_row("花花世界花花")) + checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花 花 世界 花 花 ")) + checkEvaluation(StringTrim(Seq("a花世", s)), "界", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(Seq("a@#( )", s)), "花花世界花花", create_row("aa()花花世界花花@ #")) + checkEvaluation(StringTrim(Seq("花 ", Literal("花trim"))), "trim", create_row(" abdef ")) + // scalastyle:on + checkEvaluation(StringTrim(Seq((Literal("a")), + (Literal.create(null, StringType)))), null) + } + + test("LTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimLeft(Seq(Literal(" aa "))), "aa ", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Seq("a", Literal("aa"))), "", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Seq("a ", Literal("aa "))), "", create_row(" abdef ")) @@ -424,6 +440,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(Seq("a", s)), "bdefa", create_row("abdefa")) checkEvaluation(StringTrimLeft(Seq("a ", s)), "bdefaaaa", create_row(" aaabdefaaaa")) checkEvaluation(StringTrimLeft(Seq("Spk", s)), "arkSQLS", create_row("SSparkSQLS")) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
+ checkEvaluation(StringTrimLeft(Seq(s)), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(Seq("花", s)), "世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("花 世", s)), "界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(Seq("花", s)), "a花花世界花花 ", create_row("a花花世界花花 ")) + checkEvaluation(StringTrimLeft(Seq("a花界", s)), "世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(Seq("a世界", s)), "花花世界花花", create_row("花花世界花花")) + // scalastyle:on + checkEvaluation(StringTrimLeft(Seq((Literal("a")), + (Literal.create(null, StringType)))), null) + } + + test("RTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimRight(Seq(Literal(" aa "))), " aa", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq("a", Literal("a"))), "", create_row(" abdef ")) checkEvaluation(StringTrimRight(Seq("ab", Literal("ab"))), "", create_row(" abdef ")) @@ -442,24 +474,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimRight(Seq("花a#", s)), "花花世界", create_row("花花世界花花###aa花")) checkEvaluation(StringTrimRight(Seq("花", s)), "", create_row("花花花花")) checkEvaluation(StringTrimRight(Seq("花 界b@", s)), " 花花世", create_row(" 花花世 b界@花花 ")) - checkEvaluation(StringTrimLeft(Seq(s)), "花花世界 ", create_row(" 花花世界 ")) - checkEvaluation(StringTrimLeft(Seq("花", s)), "世界花花", create_row("花花世界花花")) - checkEvaluation(StringTrimLeft(Seq("花 世", s)), "界花花", create_row(" 花花世界花花")) - checkEvaluation(StringTrimLeft(Seq("花", s)), "a花花世界花花 ", create_row("a花花世界花花 ")) - checkEvaluation(StringTrimLeft(Seq("a花界", s)), "世界花花aa", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrimLeft(Seq("a世界", s)), "花花世界花花", create_row("花花世界花花")) - checkEvaluation(StringTrim(Seq(s)), "花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(Seq("花世界", s)), "", create_row("花花世界花花")) - checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花花世界花花")) - checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花 花 世界 
花 花 ")) - checkEvaluation(StringTrim(Seq("a花世", s)), "界", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrim(Seq("a@#( )", s)), "花花世界花花", create_row("aa()花花世界花花@ #")) - checkEvaluation(StringTrim(Seq("花 ", Literal("花trim"))), "trim", create_row(" abdef ")) // scalastyle:on - checkEvaluation(StringTrim(Seq((Literal("a")), - (Literal.create(null, StringType)))), null) - checkEvaluation(StringTrimLeft(Seq((Literal("a")), - (Literal.create(null, StringType)))), null) checkEvaluation(StringTrimRight(Seq((Literal("a")), (Literal.create(null, StringType)))), null) } From a40e119c5c7fd29f621a92dac0928f2b1554cbfc Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 31 May 2017 16:23:54 -0700 Subject: [PATCH 11/21] adjust code --- .../apache/spark/unsafe/types/UTF8String.java | 41 +++--- .../spark/unsafe/types/UTF8StringSuite.java | 10 +- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../expressions/stringExpressions.scala | 130 +++++++----------- .../sql/catalyst/parser/AstBuilder.scala | 59 ++++---- .../expressions/StringExpressionsSuite.scala | 9 +- .../org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/StringFunctionsSuite.scala | 6 +- 8 files changed, 109 insertions(+), 151 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a09315931e194..dcdf7d6389f93 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -511,7 +511,7 @@ public UTF8String trim() { } /** - * Removes the given trim string from both ends of a string + * Removes the given source string starting from both ends * This method searches for each character in the source string, removes the character if it is found * in the trim string, stops at the first not found. It calls the trimLeft first, then trimRight. 
* It returns a new string in which both ends trim characters have been removed. @@ -534,14 +534,16 @@ public UTF8String trimLeft() { } /** - * Removes the given trim string from the beginning of a string + * Removes the given source string from the left end * This method searches each character in the source string starting from the left end, removes the character if it * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string */ public UTF8String trimLeft(UTF8String trimString) { - int srchIdx = 0; // the searching byte position of the input string - int trimIdx = 0; // the first beginning byte position of a non-matching character + // the searching byte position in the source string + int srchIdx = 0; + // the first beginning byte position of a non-matching character + int trimIdx = 0; while (srchIdx < numBytes) { UTF8String searchChar = copyUTF8String(srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1); @@ -560,7 +562,7 @@ public UTF8String trimLeft(UTF8String trimString) { // empty string return EMPTY_UTF8; } else { - return copyUTF8String(trimIdx, numBytes -1); + return copyUTF8String(trimIdx, numBytes - 1); } } @@ -578,28 +580,29 @@ public UTF8String trimRight() { } /** - * Removes the given trim string from the ending of a string + * Removes the given source string from the right end * This method searches each character in the source string starting from the right end, removes the character if it * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string */ public UTF8String trimRight(UTF8String trimString) { - // index trimEnd points to first no matching byte position in the input string from right side, - // it moves the number of bytes of the trimming character. 
- int trimEnd; - int trimIdx = 0; - int numChars = 0; // number of characters from the input string - int[] stringCharLen = new int[numBytes]; // array of character length for the input string - int[] stringCharPos = new int[numBytes]; // array of the first byte position for each character in the input string - //build the position and length array - while (trimIdx < numBytes) { - stringCharPos[numChars] = trimIdx; - stringCharLen[numChars]= numBytesForFirstByte(getByte(trimIdx)); - trimIdx += stringCharLen[numChars]; + int charIdx = 0; + // number of characters from the source string + int numChars = 0; + // array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + // build the position and length array + while (charIdx < numBytes) { + stringCharPos[numChars] = charIdx; + stringCharLen[numChars]= numBytesForFirstByte(getByte(charIdx)); + charIdx += stringCharLen[numChars]; numChars ++; } - trimEnd = numBytes - 1; + // index trimEnd points to the first no matching byte position from the right side of the source string. 
+ int trimEnd = numBytes - 1; while (numChars > 0) { UTF8String searchChar = copyUTF8String(stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index ea9e5a114aa32..98fa1f9ec22eb 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -234,6 +234,10 @@ public void trims() { assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + char[] charsLessThan0x20 = new char[10]; Arrays.fill(charsLessThan0x20, (char)(' ' - 1)); String stringStartingWithSpace = @@ -726,6 +730,7 @@ public void testToLong() throws IOException { assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); } } + @Test public void trim() { assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); @@ -746,10 +751,6 @@ public void trim() { @Test public void trimLeft() { - assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); - assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); - assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); - assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); @@ -765,6 +766,7 @@ public void 
trimLeft() { assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); } + @Test public void trimRight() { assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 96b6cf5a6a0dc..332dbc9355ff9 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -580,9 +580,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall - | qualifiedName '(' trimOperator=(BOTH | LEADING | TRAILING) trimChar=namedExpression + | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) trimChar=namedExpression FROM namedExpression ')' #functionCall - | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' 
fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c611b5ab99d05..d8e261c1c97ef 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -504,58 +504,63 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends ImplicitCastInputTypes { + self: Expression => + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } + } +} + /** * A function that takes a character string, removes the leading and/or trailing characters matching with the characters - * in the trim string, returns the new string. If LEADING/TRAILING/BOTH and trimStr keywords are not specified, it - * defaults to remove space character from both ends. + * in the trim string, returns the new string. If BOTH and trimStr keywords are not specified, it defaults to remove + * space character from both ends. 
* trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function * searches for each character in the source string, removes the characters from the source string until it * encounters the first non-match character. - * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. - * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. * BOTH: removes any characters from both ends of the source string that matches characters in the trim string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` - _FUNC_(LEADING trimStr FROM str) - Remove the leading trimString from `str` - _FUNC_(TRAILING trimStr FROM str) - Remove the trailing trimString from `str` """, extended = """ Arguments: str - a string expression trimString - the trim string BOTH, FROM - these are keyword to specify for trim string from both ends of the string - LEADING, FROM - these are keyword to specify for trim string from left end of the string - TRAILING, FROM - these are keyword to specify for trim string from right end of the string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); parkSQ - > SELECT _FUNC_(LEADING 'paS' FROM 'SSparkSQLS'); - rkSQLS - > SELECT _FUNC_(TRAILING 'SLQ' FROM 'SSparkSQLS'); - SSparkS """) case class StringTrim(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { + extends Expression with String2TrimExpression { require(children.size <= 2 && children.nonEmpty, s"$prettyName requires at least one argument and no more than two.") - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) - - override def nullable: Boolean 
= children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - override def prettyName: String = "trim" // trim function can take one or two arguments. - // For one argument(children size is 1), it is the trim space function. - // For two arguments(children size is 2), it is the trim function with one of these options: BOTH/LEADING/TRAILING. + // Specify one child, it is for the trim space function. + // Specify the two children, it is for the trim function with BOTH option. override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -579,39 +584,29 @@ case class StringTrim(children: Seq[Expression]) val getTrimFunction = if (children.size == 1) { s"UTF8String ${ev.value} = ${inputs(0)}.trim();" } else { - s"UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});".stripMargin + s"UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});" } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - ${getTrimFunction}; + boolean ${ev.isNull} = false + $getTrimFunction if (${ev.value} == null) { ${ev.isNull} = true; } """) - } - - override def sql: String = { - if (children.size == 1) { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } else { - val trimSQL = children(0).map(_.sql).mkString(", ") - val tarSQL = children(1).map(_.sql).mkString(", ") - s"$prettyName($trimSQL, $tarSQL)" - } } } /** - * A function that trims the characters from left end for a given string, if the trimStr is not specified, it defaults - * to trim the spaces from the left end of the source string. + * A function that trims the characters from left end for a given string, If LEADING and trimStr keywords are not + * specified, it defaults to remove space character from the left end. 
* trimStr: the function removes any characters from the left end of the source string which matches with the characters * from trimStr, it stops at the first non-match character. + * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading space characters from `str`. - _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string from the `str` + _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string """, extended = """ Arguments: @@ -624,22 +619,16 @@ case class StringTrim(children: Seq[Expression]) arkSQLS """) case class StringTrimLeft(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { + extends Expression with String2TrimExpression { require (children.size <= 2 && children.nonEmpty, "$prettyName requires at least one argument and no more than two.") - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - override def prettyName: String = "ltrim" // ltrim function can take one or two arguments. - // For one argument(children size is 1), it is the ltrim space function. - // For two arguments(children size is 2), it is the trim function with option LEADING. + // Specify one child, it is for the ltrim space function. + // Specify the two children, it is for the trim function with option LEADING. 
override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -667,31 +656,21 @@ case class StringTrimLeft(children: Seq[Expression]) } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - ${getTrimLeftFunction}; + boolean ${ev.isNull} = false + $getTrimLeftFunction if (${ev.value} == null) { ${ev.isNull} = true; } """) } - - override def sql: String = { - if (children.size == 1) { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } else { - val trimSQL = children(0).map(_.sql).mkString(", ") - val tarSQL = children(1).map(_.sql).mkString(", ") - s"$prettyName($trimSQL, $tarSQL)" - } - } } /** - * A function that trims the characters from right end for a given string, if the trimStr is not specified, it defaults - * to trim the spaces from the right end of the source string. + * A function that trims the characters from right end for a given string, If TRAILING and trimStr keywords are not + * specified, it defaults to remove space character from the right end. * trimStr: the function removes any characters from the right end of source string which matches with the characters * from trimStr, it stops at the first non-match character. + * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. 
*/ @ExpressionDescription( usage = """ @@ -709,22 +688,16 @@ case class StringTrimLeft(children: Seq[Expression]) SSpark """) case class StringTrimRight(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { + extends Expression with String2TrimExpression { require (children.size <= 2 && children.nonEmpty, "$prettyName requires at least one argument and no more than two.") - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - override def prettyName: String = "rtrim" // rtrim function can take one or two arguments. - // For one argument(children size is 1), it is the rtrim space function. - // For two arguments(children size is 2), it is the trim function with option TRAILING. + // Specify one child, it is for the rtrim space function. + // Specify the two children, it is for the trim function with option TRAILING. 
override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -751,24 +724,13 @@ case class StringTrimRight(children: Seq[Expression]) s"UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});" } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - ${getTrimRightFunction}; + boolean ${ev.isNull} = false + $getTrimRightFunction if (${ev.value} == null) { ${ev.isNull} = true; } """) } - - override def sql: String = { - if (children.size == 1) { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } else { - val trimSQL = children(0).map(_.sql).mkString(", ") - val tarSQL = children(1).map(_.sql).mkString(", ") - s"$prettyName($trimSQL, $tarSQL)" - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b2eb1a3d9b778..7dc82159ae746 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1181,9 +1181,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { // Create the function call. 
val name = ctx.qualifiedName.getText - val trimFuncName = Option(ctx.trimOperator).map { - o => visitTrimFuncName(ctx, o) - } val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) @@ -1193,8 +1190,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } - val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName, trimFuncName), - arguments, isDistinct) + val function = UnresolvedFunction( + replaceTrimFunction(visitFunctionName(ctx.qualifiedName), ctx), + arguments, + isDistinct) + // Check if the function is evaluated in a windowed context. ctx.windowSpec match { @@ -1207,19 +1207,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create a name LTRIM for TRIM(Leading), RTRIM for TRIM(Trailing), TRIM for TRIM(BOTH) + * Create a function name LTRIM for TRIM(Leading), RTRIM for TRIM(Trailing), TRIM for TRIM(BOTH), + * otherwise, returnthe original funcID. */ - private def visitTrimFuncName(ctx: FunctionCallContext, opt: Token): String = { - if (ctx.qualifiedName.getText.toLowerCase != "trim") { - throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + - s"doesn't support with option ${opt.getText}.", ctx) - } - opt.getType match { - case SqlBaseParser.BOTH => "trim" - case SqlBaseParser.LEADING => "ltrim" - case SqlBaseParser.TRAILING => "rtrim" - case _ => throw new ParseException(s"Function trim doesn't support with" + - s"type ${opt.getType}. 
Please use BOTH, LEADING or Trailing as trim type", ctx) + private def replaceTrimFunction(funcID: FunctionIdentifier, ctx: FunctionCallContext) + : FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) + } + opt.getType match { + case SqlBaseParser.BOTH => funcID + case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException(s"Function trim doesn't support with" + + s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) + } + } else { + funcID } } @@ -1239,22 +1246,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create a function database (optional) and name pair. */ - protected def visitFunctionName( - ctx: QualifiedNameContext, - trimFuncName: Option[String] = None): FunctionIdentifier = { + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { ctx.identifier().asScala.map(_.getText) match { - case Seq(db, fn) => - if (trimFuncName.isDefined) { - FunctionIdentifier(trimFuncName.get, Option(db)) - } else { - FunctionIdentifier(fn, Option(db)) - } - case Seq(fn) => - if (trimFuncName.isDefined) { - FunctionIdentifier(trimFuncName.get, None) - } else { - FunctionIdentifier(fn, None) - } + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 908c5ba16340b..1c3275bc9d5da 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -426,8 +426,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrim(Seq("a@#( )", s)), "花花世界花花", create_row("aa()花花世界花花@ #")) checkEvaluation(StringTrim(Seq("花 ", Literal("花trim"))), "trim", create_row(" abdef ")) // scalastyle:on - checkEvaluation(StringTrim(Seq((Literal("a")), - (Literal.create(null, StringType)))), null) + checkEvaluation(StringTrim(Seq(Literal("a"), Literal.create(null, StringType))), null) } test("LTRIM") { @@ -450,8 +449,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(Seq("a花界", s)), "世界花花aa", create_row("aa花花世界花花aa")) checkEvaluation(StringTrimLeft(Seq("a世界", s)), "花花世界花花", create_row("花花世界花花")) // scalastyle:on - checkEvaluation(StringTrimLeft(Seq((Literal("a")), - (Literal.create(null, StringType)))), null) + checkEvaluation(StringTrimLeft(Seq(Literal("a"), Literal.create(null, StringType))), null) } test("RTRIM") { @@ -475,8 +473,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimRight(Seq("花", s)), "", create_row("花花花花")) checkEvaluation(StringTrimRight(Seq("花 界b@", s)), " 花花世", create_row(" 花花世 b界@花花 ")) // scalastyle:on - checkEvaluation(StringTrimRight(Seq((Literal("a")), - (Literal.create(null, StringType)))), null) + checkEvaluation(StringTrimRight(Seq(Literal("a"), Literal.create(null, StringType))), null) } test("FORMAT") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index aa8f7e6a90ea9..0ae13c4a4e141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2498,7 
+2498,7 @@ object functions { * @group string_funcs * @since 2.2.0 */ - def trim(trimString: String, e: Column): Column = + def trim(e: Column, trimString: String): Column = withExpr { StringTrim(Seq(Literal(trimString), e.expr))} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 21f209c89b870..3d76b9ac33e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -168,15 +168,15 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("example ", " example", "example")) checkAnswer( - df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim("e", $"c")), + df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim($"c", "e")), Row("xample", "exampl", "xampl")) checkAnswer( - df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim("elxp", $"c")), + df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim($"c", "elxp")), Row("ample", "exa", "am")) checkAnswer( - df.select(trim("xyz", $"c")), + df.select(trim($"c", "xyz")), Row("example")) checkAnswer( From 1f3d11e3210972414234715cf18985cd45a9bd8e Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 1 Jun 2017 15:37:46 -0700 Subject: [PATCH 12/21] add missing semicolon --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d8e261c1c97ef..51180576c7e13 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -587,7 +587,7 @@ case class StringTrim(children: Seq[Expression]) 
s"UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});" } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false + boolean ${ev.isNull} = false; $getTrimFunction if (${ev.value} == null) { ${ev.isNull} = true; @@ -656,7 +656,7 @@ case class StringTrimLeft(children: Seq[Expression]) } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false + boolean ${ev.isNull} = false; $getTrimLeftFunction if (${ev.value} == null) { ${ev.isNull} = true; @@ -724,7 +724,7 @@ case class StringTrimRight(children: Seq[Expression]) s"UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});" } ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false + boolean ${ev.isNull} = false; $getTrimRightFunction if (${ev.value} == null) { ${ev.isNull} = true; From e159b904212d122a54a4a994664e259df65a04ca Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Fri, 9 Jun 2017 16:43:00 -0700 Subject: [PATCH 13/21] address comments --- .../apache/spark/unsafe/types/UTF8String.java | 6 +-- .../spark/unsafe/types/UTF8StringSuite.java | 9 ++-- .../expressions/stringExpressions.scala | 45 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 2 +- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index dcdf7d6389f93..edf0270713af5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -511,7 +511,7 @@ public UTF8String trim() { } /** - * Removes the given source string starting from both ends + * Based on the given trim string, trim this string starting from both ends * This method searches for each character in the source string, removes the character if it is found * in the trim string, stops at the first not found. 
It calls the trimLeft first, then trimRight. * It returns a new string in which both ends trim characters have been removed. @@ -534,7 +534,7 @@ public UTF8String trimLeft() { } /** - * Removes the given source string from the left end + * Based on the given trim string, trim this string starting from left end * This method searches each character in the source string starting from the left end, removes the character if it * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string @@ -580,7 +580,7 @@ public UTF8String trimRight() { } /** - * Removes the given source string from the right end + * Based on the given trim string, trim this string starting from right end * This method searches each character in the source string starting from the right end, removes the character if it * is in the trim string, stops at the first character which is not in the trim string, returns the new string. * @param trimString the trim character string diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 98fa1f9ec22eb..f0860018d5642 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -732,7 +732,7 @@ public void testToLong() throws IOException { } @Test - public void trim() { + public void trimBothWithTrimString() { assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); @@ -750,7 +750,10 @@ public void trim() { } @Test - public void trimLeft() { + public void trimLeftWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello 
").trimLeft(fromString(""))); + assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); @@ -768,7 +771,7 @@ public void trimLeft() { } @Test - public void trimRight() { + public void trimRightWithTrimString() { assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 51180576c7e13..b9f1e2d717686 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -504,8 +504,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } -trait String2TrimExpression extends ImplicitCastInputTypes { - self: Expression => +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -526,9 +525,12 @@ trait String2TrimExpression extends ImplicitCastInputTypes { } /** - * A function that takes a character string, removes the leading and/or trailing characters matching with the characters - * in the trim string, returns the new string. 
If BOTH and trimStr keywords are not specified, it defaults to remove - * space character from both ends. + * A function that takes a character string, removes the leading and trailing characters matching with the characters + * in the trim string, returns the new string. + * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim + * function will have one argument, which contains the source string. + * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have + * two arguments, the first argument contains trimStr, the second argument contains the source string. * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function * searches for each character in the source string, removes the characters from the source string until it * encounters the first non-match character. @@ -551,18 +553,15 @@ trait String2TrimExpression extends ImplicitCastInputTypes { parkSQ """) case class StringTrim(children: Seq[Expression]) - extends Expression with String2TrimExpression { + extends String2TrimExpression { require(children.size <= 2 && children.nonEmpty, s"$prettyName requires at least one argument and no more than two.") override def prettyName: String = "trim" - // trim function can take one or two arguments. - // Specify one child, it is for the trim space function. - // Specify the two children, it is for the trim function with BOTH option. override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]).reverse if (inputs(0) != null) { if (children.size == 1) { return inputs(0).trim() @@ -580,7 +579,7 @@ case class StringTrim(children: Seq[Expression]) val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? 
null : ${eval.value}" - } + }.reverse val getTrimFunction = if (children.size == 1) { s"UTF8String ${ev.value} = ${inputs(0)}.trim();" } else { @@ -597,8 +596,11 @@ case class StringTrim(children: Seq[Expression]) } /** - * A function that trims the characters from left end for a given string, If LEADING and trimStr keywords are not - * specified, it defaults to remove space character from the left end. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. * trimStr: the function removes any characters from the left end of the source string which matches with the characters * from trimStr, it stops at the first non-match character. * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. @@ -619,16 +621,13 @@ case class StringTrim(children: Seq[Expression]) arkSQLS """) case class StringTrimLeft(children: Seq[Expression]) - extends Expression with String2TrimExpression { + extends String2TrimExpression { require (children.size <= 2 && children.nonEmpty, "$prettyName requires at least one argument and no more than two.") override def prettyName: String = "ltrim" - // ltrim function can take one or two arguments. - // Specify one child, it is for the ltrim space function. - // Specify the two children, it is for the trim function with option LEADING. 
override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { @@ -666,8 +665,11 @@ case class StringTrimLeft(children: Seq[Expression]) } /** - * A function that trims the characters from right end for a given string, If TRAILING and trimStr keywords are not - * specified, it defaults to remove space character from the right end. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. * trimStr: the function removes any characters from the right end of source string which matches with the characters * from trimStr, it stops at the first non-match character. * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. @@ -688,16 +690,13 @@ case class StringTrimLeft(children: Seq[Expression]) SSpark """) case class StringTrimRight(children: Seq[Expression]) - extends Expression with String2TrimExpression { + extends String2TrimExpression { require (children.size <= 2 && children.nonEmpty, "$prettyName requires at least one argument and no more than two.") override def prettyName: String = "rtrim" - // rtrim function can take one or two arguments. - // Specify one child, it is for the rtrim space function. - // Specify the two children, it is for the trim function with option TRAILING. 
override def eval(input: InternalRow): Any = { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) if (inputs(0) != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7dc82159ae746..5c98a37975bb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1208,7 +1208,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create a function name LTRIM for TRIM(Leading), RTRIM for TRIM(Trailing), TRIM for TRIM(BOTH), - * otherwise, returnthe original funcID. + * otherwise, return the original function identifier. */ private def replaceTrimFunction(funcID: FunctionIdentifier, ctx: FunctionCallContext) : FunctionIdentifier = { From 3def573f50c284006d573576b5b88281b0392db8 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Sun, 16 Jul 2017 22:44:17 -0700 Subject: [PATCH 14/21] change the case class for trims --- .../apache/spark/unsafe/types/UTF8String.java | 6 +- .../expressions/stringExpressions.scala | 228 ++++++++++++------ .../expressions/StringExpressionsSuite.scala | 97 ++++---- .../org/apache/spark/sql/functions.scala | 12 +- 4 files changed, 215 insertions(+), 128 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index edf0270713af5..21a3919e87103 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -518,7 +518,11 @@ public UTF8String trim() { * @param trimString the trim character string */ public UTF8String trim(UTF8String trimString) { - return trimLeft(trimString).trimRight(trimString); + if (trimString != 
null) { + return trimLeft(trimString).trimRight(trimString); + } else { + return null; + } } public UTF8String trimLeft() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b9f1e2d717686..0a6a5b42f12b9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -524,6 +524,11 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { } } +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** * A function that takes a character string, removes the leading and trailing characters matching with the characters * in the trim string, returns the new string. 
@@ -552,49 +557,76 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); parkSQ """) -case class StringTrim(children: Seq[Expression]) +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) extends String2TrimExpression { - require(children.size <= 2 && children.nonEmpty, - s"$prettyName requires at least one argument and no more than two.") + def this (srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]).reverse - if (inputs(0) != null) { - if (children.size == 1) { - return inputs(0).trim() - } else if (inputs(1) != null) { - return inputs(1).trim(inputs(0)) + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trim() } } null } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && !children(0).isInstanceOf[Literal]) { - throw new AnalysisException(s"The trimming parameter should be Literal.")} - val evals = children.map(_.genCode(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? 
null : ${eval.value}" - }.reverse - val getTrimFunction = if (children.size == 1) { - s"UTF8String ${ev.value} = ${inputs(0)}.trim();" + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + } + """.stripMargin) } else { - s"UTF8String ${ev.value} = ${inputs(1)}.trim(${inputs(0)});" - } - ev.copy(evals.map(_.code).mkString("\n") + s""" + val trimString = evals(1) + val getTrimFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" boolean ${ev.isNull} = false; - $getTrimFunction - if (${ev.value} == null) { + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { ${ev.isNull} = true; + } else { + $getTrimFunction } - """) + """.stripMargin) + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** * A function that trims the characters from left end for a given string. * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. 
The ltrim @@ -620,50 +652,76 @@ case class StringTrim(children: Seq[Expression]) > SELECT _FUNC_('Sp', 'SSparkSQLS'); arkSQLS """) -case class StringTrimLeft(children: Seq[Expression]) +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) extends String2TrimExpression { - require (children.size <= 2 && children.nonEmpty, - "$prettyName requires at least one argument and no more than two.") + def this(srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - if (inputs(0) != null) { - if (children.size == 1) { - return inputs(0).trimLeft() - } else if (inputs(1) != null) { - return inputs(1).trimLeft(inputs(0)) + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trimLeft() } } null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && !children(0).isInstanceOf[Literal]) { - throw new AnalysisException(s"The trimming parameter should be Literal.")} - val evals = children.map(_.genCode(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? 
null : ${eval.value}" - } - val getTrimLeftFunction = if (children.size == 1) { - s"UTF8String ${ev.value} = ${inputs(0)}.trimLeft();" + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""".stripMargin) } else { - s"UTF8String ${ev.value} = ${inputs(1)}.trimLeft(${inputs(0)});" + val trimString = evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + } + """.stripMargin ) } - - ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - $getTrimLeftFunction - if (${ev.value} == null) { - ${ev.isNull} = true; - } - """) } } +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) +} + /** * A function that trims the characters from right end for a given string. * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. 
The @@ -689,46 +747,68 @@ case class StringTrimLeft(children: Seq[Expression]) > SELECT _FUNC_('LQSa', 'SSparkSQLS'); SSpark """) -case class StringTrimRight(children: Seq[Expression]) +case class StringTrimRight( + srcStr: Expression, + trimStr: Option[Expression] = None) extends String2TrimExpression { - require (children.size <= 2 && children.nonEmpty, - "$prettyName requires at least one argument and no more than two.") + def this(srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "rtrim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - if (inputs(0) != null) { - if (children.size == 1) { - return inputs(0).trimRight() - } else if (inputs(1) != null) { - return inputs(1).trimRight(inputs(0)) + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trimRight() } } null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (children.size == 2 && !children(0).isInstanceOf[Literal]) { - throw new AnalysisException(s"The trimming parameter should be Literal.")} - val evals = children.map(_.genCode(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? 
null : ${eval.value}" - } - val getTrimRightFunction = if (children.size == 1) { - s"UTF8String ${ev.value} = ${inputs(0)}.trimRight();" + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(); + }""".stripMargin) } else { - s"UTF8String ${ev.value} = ${inputs(1)}.trimRight(${inputs(0)});" + val trimString = evals(1) + val getTrimRightFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimRightFunction + } + """.stripMargin ) } - ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - $getTrimRightFunction - if (${ev.value} == null) { - ${ev.isNull} = true; - } - """) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1c3275bc9d5da..18ef4bc37c2b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -407,73 +407,76 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("TRIM") { val s = 'a.string.at(0) - checkEvaluation(StringTrim(Seq(Literal(" aa "))), "aa", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("a", Literal("aa"))), "", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("ab cd", Literal(" aabbtrimccc"))), "trim", create_row("bdef")) - 
checkEvaluation(StringTrim(Seq("a.,@<>", Literal("a@>.,>"))), " ", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq(s)), "abdef", create_row(" abdef ")) - checkEvaluation(StringTrim(Seq("abd", s)), "ef", create_row("abdefa")) - checkEvaluation(StringTrim(Seq("a", s)), "bdef", create_row("aaabdefaaaa")) - checkEvaluation(StringTrim(Seq("SLSQ", s)), "park", create_row("SSparkSQLS")) + checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim("aa", "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrim(Literal(" aabbtrimccc"), "ab cd"), "trim", create_row("bdef")) + checkEvaluation(StringTrim(Literal("a@>.,>"), "a.,@<>"), " ", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + checkEvaluation(StringTrim(s, "abd"), "ef", create_row("abdefa")) + checkEvaluation(StringTrim(s, "a"), "bdef", create_row("aaabdefaaaa")) + checkEvaluation(StringTrim(s, "SLSQ"), "park", create_row("SSparkSQLS")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
- checkEvaluation(StringTrim(Seq(s)), "花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(Seq("花世界", s)), "", create_row("花花世界花花")) - checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花花世界花花")) - checkEvaluation(StringTrim(Seq("花 ", s)), "世界", create_row(" 花 花 世界 花 花 ")) - checkEvaluation(StringTrim(Seq("a花世", s)), "界", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrim(Seq("a@#( )", s)), "花花世界花花", create_row("aa()花花世界花花@ #")) - checkEvaluation(StringTrim(Seq("花 ", Literal("花trim"))), "trim", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s, "花世界"), "", create_row("花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花 花 世界 花 花 ")) + checkEvaluation(StringTrim(s, "a花世"), "界", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(s, "a@#( )"), "花花世界花花", create_row("aa()花花世界花花@ #")) + checkEvaluation(StringTrim(Literal("花trim"), "花 "), "trim", create_row(" abdef ")) // scalastyle:on - checkEvaluation(StringTrim(Seq(Literal("a"), Literal.create(null, StringType))), null) + checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null) } test("LTRIM") { val s = 'a.string.at(0) - checkEvaluation(StringTrimLeft(Seq(Literal(" aa "))), "aa ", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("a", Literal("aa"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("a ", Literal("aa "))), "", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("ab", Literal("aabbcaaaa"))), "caaaa", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq(s)), "abdef ", create_row(" abdef ")) - checkEvaluation(StringTrimLeft(Seq("a", s)), "bdefa", create_row("abdefa")) - checkEvaluation(StringTrimLeft(Seq("a ", s)), "bdefaaaa", create_row(" aaabdefaaaa")) - 
checkEvaluation(StringTrimLeft(Seq("Spk", s)), "arkSQLS", create_row("SSparkSQLS")) + checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa "), "a "), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aabbcaaaa"), "ab"), "caaaa", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s, "a"), "bdefa", create_row("abdefa")) + checkEvaluation(StringTrimLeft(s, "a "), "bdefaaaa", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimLeft(s, "Spk"), "arkSQLS", create_row("SSparkSQLS")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. - checkEvaluation(StringTrimLeft(Seq(s)), "花花世界 ", create_row(" 花花世界 ")) - checkEvaluation(StringTrimLeft(Seq("花", s)), "世界花花", create_row("花花世界花花")) - checkEvaluation(StringTrimLeft(Seq("花 世", s)), "界花花", create_row(" 花花世界花花")) - checkEvaluation(StringTrimLeft(Seq("花", s)), "a花花世界花花 ", create_row("a花花世界花花 ")) - checkEvaluation(StringTrimLeft(Seq("a花界", s)), "世界花花aa", create_row("aa花花世界花花aa")) - checkEvaluation(StringTrimLeft(Seq("a世界", s)), "花花世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s, "花"), "世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花 世"), "界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花"), "a花花世界花花 ", create_row("a花花世界花花 ")) + checkEvaluation(StringTrimLeft(s, "a花界"), "世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(s, "a世界"), "花花世界花花", create_row("花花世界花花")) // scalastyle:on - checkEvaluation(StringTrimLeft(Seq(Literal("a"), Literal.create(null, StringType))), null) + checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null) + 
checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null) } test("RTRIM") { val s = 'a.string.at(0) - checkEvaluation(StringTrimRight(Seq(Literal(" aa "))), " aa", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a", Literal("a"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("ab", Literal("ab"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a %", Literal("aabbaaaa %"))), "aabb", create_row("def")) - checkEvaluation(StringTrimRight(Seq(s)), " abdef", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a", s)), "abdef", create_row("abdefa")) - checkEvaluation(StringTrimRight(Seq("abf de", s)), "", create_row(" aaabdefaaaa")) - checkEvaluation(StringTrimRight(Seq("S*&", s)), "SSparkSQL", create_row("SSparkSQLS*")) + checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("a"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("ab"), "ab"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("aabbaaaa %"), "a %"), "aabb", create_row("def")) + checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s, "a"), "abdef", create_row("abdefa")) + checkEvaluation(StringTrimRight(s, "abf de"), "", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimRight(s, "S*&"), "SSparkSQL", create_row("SSparkSQLS*")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
- checkEvaluation(StringTrimRight(Seq("花", Literal("a"))), "a", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("a", Literal("花"))), "花", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq("界花世", Literal("花花世界"))), "", create_row(" abdef ")) - checkEvaluation(StringTrimRight(Seq(s)), " 花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrimRight(Seq("花a#", s)), "花花世界", create_row("花花世界花花###aa花")) - checkEvaluation(StringTrimRight(Seq("花", s)), "", create_row("花花花花")) - checkEvaluation(StringTrimRight(Seq("花 界b@", s)), " 花花世", create_row(" 花花世 b界@花花 ")) + checkEvaluation(StringTrimRight(Literal("a"), "花"), "a", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花"), "a"), "花", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花花世界"), "界花世"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimRight(s, "花a#"), "花花世界", create_row("花花世界花花###aa花")) + checkEvaluation(StringTrimRight(s, "花"), "", create_row("花花花花")) + checkEvaluation(StringTrimRight(s, "花 界b@"), " 花花世", create_row(" 花花世 b界@花花 ")) // scalastyle:on - checkEvaluation(StringTrimRight(Seq(Literal("a"), Literal.create(null, StringType))), null) + checkEvaluation(StringTrimRight(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null) } test("FORMAT") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0ae13c4a4e141..a8dda7f53c809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2331,7 +2331,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = withExpr {StringTrimLeft(Seq(e.expr))} + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr)} /** * Trim the specified 
character string from left end for the specified string column. @@ -2339,7 +2339,7 @@ object functions { * @since 2.2.0 */ def ltrim(e: Column, trimString: String): Column = - withExpr { StringTrimLeft(Seq(Literal(trimString), e.expr))} + withExpr { StringTrimLeft(e.expr, Literal(trimString))} /** * Extract a specific group matched by a Java regex, from the specified string column. @@ -2416,7 +2416,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = withExpr { StringTrimRight(Seq(e.expr)) } + def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** * Trim the specified character string from right end for the specified string column. @@ -2424,7 +2424,7 @@ object functions { * @since 2.2.0 */ def rtrim(e: Column, trimString: String): Column = - withExpr { StringTrimRight(Seq(Literal(trimString), e.expr))} + withExpr { StringTrimRight(e.expr, Literal(trimString))} /** * Returns the soundex code for the specified expression. @@ -2491,7 +2491,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = withExpr { StringTrim(Seq(e.expr)) } + def trim(e: Column): Column = withExpr { StringTrim(e.expr) } /** * Trim the specified character from both ends for the specified string column. @@ -2499,7 +2499,7 @@ object functions { * @since 2.2.0 */ def trim(e: Column, trimString: String): Column = - withExpr { StringTrim(Seq(Literal(trimString), e.expr))} + withExpr { StringTrim(e.expr, Literal(trimString))} /** * Converts a string column to upper case. 
From 90600cb17c97df5822a8a66b7dfa9856a08ddc27 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 17 Jul 2017 23:37:48 -0700 Subject: [PATCH 15/21] rebase --- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 4 ++++ .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 4 ++-- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 21a3919e87103..bd12855306723 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -544,6 +544,8 @@ public UTF8String trimLeft() { * @param trimString the trim character string */ public UTF8String trimLeft(UTF8String trimString) { + if (trimString == null) + return null; // the searching byte position in the source string int srchIdx = 0; // the first beginning byte position of a non-matching character @@ -590,6 +592,8 @@ public UTF8String trimRight() { * @param trimString the trim character string */ public UTF8String trimRight(UTF8String trimString) { + if (trimString == null) + return null; int charIdx = 0; // number of characters from the source string int numChars = 0; diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 332dbc9355ff9..d0a54288780ea 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -580,8 +580,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? 
#functionCall - | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) trimChar=namedExpression - FROM namedExpression ')' #functionCall + | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression + FROM argument+=expression ')' #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0a6a5b42f12b9..2de3a49ea0550 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -562,7 +562,7 @@ case class StringTrim( trimStr: Option[Expression] = None) extends String2TrimExpression { - def this (srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + def this (trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) def this(srcStr: Expression) = this(srcStr, None) @@ -657,7 +657,7 @@ case class StringTrimLeft( trimStr: Option[Expression] = None) extends String2TrimExpression { - def this(srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) def this(srcStr: Expression) = this(srcStr, None) @@ -752,7 +752,7 @@ case class StringTrimRight( trimStr: Option[Expression] = None) extends String2TrimExpression { - def this(srcStr: Expression, trimStr: Expression) = this(srcStr, Option(trimStr)) + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) def this(srcStr: Expression) = this(srcStr, None) From 88886bfc3febab2bea28604791579bc5a5560b23 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 19 Jul 2017 22:30:26 -0700 
Subject: [PATCH 16/21] change the version number to 2.3 --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a8dda7f53c809..05d3662d93604 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2336,7 +2336,7 @@ object functions { /** * Trim the specified character string from left end for the specified string column. * @group string_funcs - * @since 2.2.0 + * @since 2.3.0 */ def ltrim(e: Column, trimString: String): Column = withExpr { StringTrimLeft(e.expr, Literal(trimString))} @@ -2421,7 +2421,7 @@ object functions { /** * Trim the specified character string from right end for the specified string column. * @group string_funcs - * @since 2.2.0 + * @since 2.3.0 */ def rtrim(e: Column, trimString: String): Column = withExpr { StringTrimRight(e.expr, Literal(trimString))} @@ -2496,7 +2496,7 @@ object functions { /** * Trim the specified character from both ends for the specified string column. 
* @group string_funcs - * @since 2.2.0 + * @since 2.3.0 */ def trim(e: Column, trimString: String): Column = withExpr { StringTrim(e.expr, Literal(trimString))} From 1309bc14cac85bf115eb7162f7e6f3887b8bbfa6 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Sat, 12 Aug 2017 20:16:47 -0700 Subject: [PATCH 17/21] address comments --- .../expressions/stringExpressions.scala | 6 +-- .../sql/catalyst/parser/AstBuilder.scala | 50 ++++++++----------- .../sql/catalyst/parser/PlanParserSuite.scala | 18 +++++++ .../org/apache/spark/sql/functions.scala | 17 ++++--- 4 files changed, 53 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2de3a49ea0550..4060adf215683 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -546,7 +546,7 @@ object StringTrim { _FUNC_(str) - Removes the leading and trailing space characters from `str`. _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` """, - extended = """ + examples = """ Arguments: str - a string expression trimString - the trim string @@ -642,7 +642,7 @@ object StringTrimLeft { _FUNC_(str) - Removes the leading space characters from `str`. _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string """, - extended = """ + examples = """ Arguments: str - a string expression trimStr - the trim string @@ -737,7 +737,7 @@ object StringTrimRight { _FUNC_(str) - Removes the trailing space characters from `str`. 
_FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` """, - extended = """ + examples = """ Arguments: str - a string expression trimStr - the trim string diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5c98a37975bb1..c1c59e03da26e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1179,6 +1179,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a (windowed) Function expression. */ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + def replaceFunctions( + funcID: FunctionIdentifier, + ctx: FunctionCallContext): FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) + } + opt.getType match { + case SqlBaseParser.BOTH => funcID + case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException("Function trim doesn't support with " + + s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) + } + } else { + funcID + } + } // Create the function call. 
val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) @@ -1190,10 +1210,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } - val function = UnresolvedFunction( - replaceTrimFunction(visitFunctionName(ctx.qualifiedName), ctx), - arguments, - isDistinct) + val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx) + val function = UnresolvedFunction(funcId, arguments, isDistinct) // Check if the function is evaluated in a windowed context. @@ -1206,30 +1224,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } - /** - * Create a function name LTRIM for TRIM(Leading), RTRIM for TRIM(Trailing), TRIM for TRIM(BOTH), - * otherwise, return the original function identifier. - */ - private def replaceTrimFunction(funcID: FunctionIdentifier, ctx: FunctionCallContext) - : FunctionIdentifier = { - val opt = ctx.trimOption - if (opt != null) { - if (ctx.qualifiedName.getText.toLowerCase != "trim") { - throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + - s"doesn't support with option ${opt.getText}.", ctx) - } - opt.getType match { - case SqlBaseParser.BOTH => funcID - case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") - case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") - case _ => throw new ParseException(s"Function trim doesn't support with" + - s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) - } - } else { - funcID - } - } - /** * Create a current timestamp/date expression. These are different from regular function because * they do not require the user to specify braces when calling them. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index b0d2fb26a6006..306e6f2cfbd37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -651,4 +651,22 @@ class PlanParserSuite extends AnalysisTest { ) ) } + + test("TRIM function") { + intercept("select ltrim(both 'S' from 'SS abc S'", "missing ')' at ''") + intercept("select rtrim(trailing 'S' from 'SS abc S'", "missing ')' at ''") + + assertEqual( + "SELECT TRIM(BOTH '@$%&( )abc' FROM '@ $ % & ()abc ' )", + OneRowRelation().select('TRIM.function("@$%&( )abc", "@ $ % & ()abc ")) + ) + assertEqual( + "SELECT TRIM(LEADING 'c []' FROM '[ ccccbcc ')", + OneRowRelation().select('ltrim.function("c []", "[ ccccbcc ")) + ) + assertEqual( + "SELECT TRIM(TRAILING 'c&^,.' FROM 'bc...,,,&&&ccc')", + OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 05d3662d93604..c6d0d86384b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2331,15 +2331,16 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr)} + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** * Trim the specified character string from left end for the specified string column. 
* @group string_funcs * @since 2.3.0 */ - def ltrim(e: Column, trimString: String): Column = - withExpr { StringTrimLeft(e.expr, Literal(trimString))} + def ltrim(e: Column, trimString: String): Column = withExpr { + StringTrimLeft(e.expr, Literal(trimString)) + } /** * Extract a specific group matched by a Java regex, from the specified string column. @@ -2423,8 +2424,9 @@ object functions { * @group string_funcs * @since 2.3.0 */ - def rtrim(e: Column, trimString: String): Column = - withExpr { StringTrimRight(e.expr, Literal(trimString))} + def rtrim(e: Column, trimString: String): Column = withExpr { + StringTrimRight(e.expr, Literal(trimString)) + } /** * Returns the soundex code for the specified expression. @@ -2498,8 +2500,9 @@ object functions { * @group string_funcs * @since 2.3.0 */ - def trim(e: Column, trimString: String): Column = - withExpr { StringTrim(e.expr, Literal(trimString))} + def trim(e: Column, trimString: String): Column = withExpr { + StringTrim(e.expr, Literal(trimString)) + } /** * Converts a string column to upper case. From f7d8c06def6086b5e81fbb8aa98f3085ad10e7a6 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Sun, 27 Aug 2017 23:33:09 -0700 Subject: [PATCH 18/21] fix the function description for arguments --- .../expressions/stringExpressions.scala | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 4060adf215683..97abcc8c52669 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -546,11 +546,13 @@ object StringTrim { _FUNC_(str) - Removes the leading and trailing space characters from `str`. 
_FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` """, - examples = """ + arguments = """ Arguments: - str - a string expression - trimString - the trim string - BOTH, FROM - these are keyword to specify for trim string from both ends of the string + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, + examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL @@ -642,10 +644,13 @@ object StringTrimLeft { _FUNC_(str) - Removes the leading space characters from `str`. _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string """, - examples = """ + arguments = """ Arguments: - str - a string expression - trimStr - the trim string + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, + examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL @@ -737,10 +742,13 @@ object StringTrimRight { _FUNC_(str) - Removes the trailing space characters from `str`. 
_FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` """, - examples = """ + arguments = """ Arguments: - str - a string expression - trimStr - the trim string + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, + examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL From d107f44b34b574c43e0ee15de7622a44ea6cfe52 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 5 Sep 2017 23:36:14 -0700 Subject: [PATCH 19/21] address comments --- .../expressions/stringExpressions.scala | 60 ++++++++----------- .../parser/TableIdentifierParserSuite.scala | 2 +- 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 97abcc8c52669..02b45b1a58894 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -511,17 +511,6 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - - override def sql: String = { - if (children.size == 1) { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } else { - val trimSQL = children(0).map(_.sql).mkString(", ") - val tarSQL = children(1).map(_.sql).mkString(", ") - s"$prettyName($trimSQL, $tarSQL)" - } - } } object StringTrim { @@ -539,8 +528,8 @@ object StringTrim { * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function * searches for each character in the source string, removes the 
characters from the source string until it * encounters the first non-match character. - * BOTH: removes any characters from both ends of the source string that matches characters in the trim string. - */ + * BOTH: removes any character from both ends of the source string that matches characters in the trim string. + */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the leading and trailing space characters from `str`. @@ -577,14 +566,15 @@ case class StringTrim( } override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] - if (srcString != null) { + if (srcString == null) { + null + } else { if (trimStr.isDefined) { return srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { return srcString.trim() } } - null } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -599,8 +589,7 @@ case class StringTrim( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trim(); - } - """.stripMargin) + }""".stripMargin) } else { val trimString = evals(1) val getTrimFunction = @@ -609,17 +598,16 @@ case class StringTrim( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); - }""".stripMargin + }""".stripMargin ev.copy(evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - UTF8String ${ev.value} = null; - if (${srcString.isNull}) { - ${ev.isNull} = true; - } else { - $getTrimFunction - } - """.stripMargin) + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + }""".stripMargin) } } } @@ -637,7 +625,7 @@ object StringTrimLeft { * have two arguments, the first argument contains trimStr, the second argument contains the source string. * trimStr: the function removes any characters from the left end of the source string which matches with the characters * from trimStr, it stops at the first non-match character. 
- * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. + * LEADING: removes any character from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( usage = """ @@ -676,14 +664,15 @@ case class StringTrimLeft( override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] - if (srcString != null) { + if (srcString == null) { + null + } else { if (trimStr.isDefined) { return srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { return srcString.trimLeft() } } - null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -716,8 +705,7 @@ case class StringTrimLeft( ${ev.isNull} = true; } else { $getTrimLeftFunction - } - """.stripMargin ) + }""".stripMargin ) } } } @@ -735,7 +723,7 @@ object StringTrimRight { * have two arguments, the first argument contains trimStr, the second argument contains the source string. * trimStr: the function removes any characters from the right end of source string which matches with the characters * from trimStr, it stops at the first non-match character. - * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. + * TRAILING: removes any character from the right end of the source string that matches characters in the trim string. 
*/ @ExpressionDescription( usage = """ @@ -774,14 +762,15 @@ case class StringTrimRight( override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] - if (srcString != null) { + if (srcString == null) { + null + } else { if (trimStr.isDefined) { return srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { return srcString.trimRight() } } - null } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -814,8 +803,7 @@ case class StringTrimRight( ${ev.isNull} = true; } else { $getTrimRightFunction - } - """.stripMargin ) + }""".stripMargin ) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 76be6ee3f50bc..cc80a41df998d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", From 6da0f8f50e319e000d1ea1c258e97bac6b2fde01 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 6 Sep 2017 12:26:31 -0700 Subject: [PATCH 20/21] fix style --- .../expressions/stringExpressions.scala | 20 +++++++++---------- 1 file 
changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 02b45b1a58894..0a8922510d745 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -589,7 +589,7 @@ case class StringTrim( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trim(); - }""".stripMargin) + }""") } else { val trimString = evals(1) val getTrimFunction = @@ -598,7 +598,7 @@ case class StringTrim( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); - }""".stripMargin + }""" ev.copy(evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -607,7 +607,7 @@ case class StringTrim( ${ev.isNull} = true; } else { $getTrimFunction - }""".stripMargin) + }""") } } } @@ -687,7 +687,7 @@ case class StringTrimLeft( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trimLeft(); - }""".stripMargin) + }""") } else { val trimString = evals(1) val getTrimLeftFunction = @@ -696,7 +696,7 @@ case class StringTrimLeft( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); - }""".stripMargin + }""" ev.copy(evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -705,7 +705,7 @@ case class StringTrimLeft( ${ev.isNull} = true; } else { $getTrimLeftFunction - }""".stripMargin ) + }""") } } } @@ -769,8 +769,8 @@ case class StringTrimRight( return srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { return srcString.trimRight() - } } + } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -785,7 +785,7 @@ case class StringTrimRight( ${ev.isNull} = true; } else { ${ev.value} = 
${srcString.value}.trimRight(); - }""".stripMargin) + }""") } else { val trimString = evals(1) val getTrimRightFunction = @@ -794,7 +794,7 @@ case class StringTrimRight( ${ev.isNull} = true; } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); - }""".stripMargin + }""" ev.copy(evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -803,7 +803,7 @@ case class StringTrimRight( ${ev.isNull} = true; } else { $getTrimRightFunction - }""".stripMargin ) + }""") } } } From 79846bfd86c6265ebe7d14906853bfe2ec0467f3 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Sun, 17 Sep 2017 22:09:07 -0700 Subject: [PATCH 21/21] address comments --- .../apache/spark/unsafe/types/UTF8String.java | 8 ++-- .../expressions/stringExpressions.scala | 37 +++++++++---------- .../sql/catalyst/parser/AstBuilder.scala | 34 ++++++++--------- 3 files changed, 37 insertions(+), 42 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index bd12855306723..aba39c548e8e5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -544,8 +544,7 @@ public UTF8String trimLeft() { * @param trimString the trim character string */ public UTF8String trimLeft(UTF8String trimString) { - if (trimString == null) - return null; + if (trimString == null) return null; // the searching byte position in the source string int srchIdx = 0; // the first beginning byte position of a non-matching character @@ -592,8 +591,7 @@ public UTF8String trimRight() { * @param trimString the trim character string */ public UTF8String trimRight(UTF8String trimString) { - if (trimString == null) - return null; + if (trimString == null) return null; int charIdx = 0; // number of characters from the source string int numChars = 0; @@ -604,7 +602,7 @@ public UTF8String 
trimRight(UTF8String trimString) { // build the position and length array while (charIdx < numBytes) { stringCharPos[numChars] = charIdx; - stringCharLen[numChars]= numBytesForFirstByte(getByte(charIdx)); + stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); charIdx += stringCharLen[numChars]; numChars ++; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0a8922510d745..83de515079eea 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -519,7 +519,7 @@ object StringTrim { } /** - * A function that takes a character string, removes the leading and trailing characters matching with the characters + * A function that takes a character string, removes the leading and trailing characters matching with any character * in the trim string, returns the new string. * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim * function will have one argument, which contains the source string. 
@@ -553,7 +553,7 @@ case class StringTrim( trimStr: Option[Expression] = None) extends String2TrimExpression { - def this (trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) def this(srcStr: Expression) = this(srcStr, None) @@ -570,9 +570,9 @@ case class StringTrim( null } else { if (trimStr.isDefined) { - return srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { - return srcString.trim() + srcString.trim() } } } @@ -582,7 +582,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -599,8 +599,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString("\n") + - s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -623,7 +622,7 @@ object StringTrimLeft { * function will have one argument, which contains the source string. * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will * have two arguments, the first argument contains trimStr, the second argument contains the source string. - * trimStr: the function removes any characters from the left end of the source string which matches with the characters + * trimStr: the function removes any character from the left end of the source string which matches with the characters * from trimStr, it stops at the first non-match character. * LEADING: removes any character from the left end of the source string that matches characters in the trim string. 
*/ @@ -668,9 +667,9 @@ case class StringTrimLeft( null } else { if (trimStr.isDefined) { - return srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { - return srcString.trimLeft() + srcString.trimLeft() } } } @@ -680,7 +679,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -697,8 +696,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString("\n") + - s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -721,14 +719,14 @@ object StringTrimRight { * rtrim function will have one argument, which contains the source string. * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will * have two arguments, the first argument contains trimStr, the second argument contains the source string. - * trimStr: the function removes any characters from the right end of source string which matches with the characters + * trimStr: the function removes any character from the right end of source string which matches with the characters * from trimStr, it stops at the first non-match character. * TRAILING: removes any character from the right end of the source string that matches characters in the trim string. */ @ExpressionDescription( usage = """ _FUNC_(str) - Removes the trailing space characters from `str`. 
- _FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` + _FUNC_(trimStr, str) - Removes the trailing string which contains the characters from the trim string from the `str` """, arguments = """ Arguments: @@ -766,9 +764,9 @@ case class StringTrimRight( null } else { if (trimStr.isDefined) { - return srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) + srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) } else { - return srcString.trimRight() + srcString.trimRight() } } } @@ -778,7 +776,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -795,8 +793,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString("\n") + - s""" + ev.copy(evals.map(_.code).mkString + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c1c59e03da26e..85b492e83446e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1180,25 +1180,25 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { def replaceFunctions( - funcID: FunctionIdentifier, - ctx: FunctionCallContext): FunctionIdentifier = { - val opt = ctx.trimOption - if (opt != null) { - if (ctx.qualifiedName.getText.toLowerCase != "trim") { - throw new 
ParseException(s"The specified function ${ctx.qualifiedName.getText} " + - s"doesn't support with option ${opt.getText}.", ctx) - } - opt.getType match { - case SqlBaseParser.BOTH => funcID - case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") - case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") - case _ => throw new ParseException("Function trim doesn't support with " + - s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) - } - } else { - funcID + funcID: FunctionIdentifier, + ctx: FunctionCallContext): FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) } + opt.getType match { + case SqlBaseParser.BOTH => funcID + case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException("Function trim doesn't support with " + + s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) + } + } else { + funcID } + } // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)