From 1a2e6110478642f30487e569f2f3645ef058bc78 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 01:39:08 -0700 Subject: [PATCH 1/5] [SPARK-9157][SQL] codegen substring --- .../expressions/stringOperations.scala | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5f8ac716f79a1..6b752a110a520 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -584,7 +584,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -593,12 +593,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") - } - if (str.dataType == BinaryType) str.dataType else StringType - } + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) @@ -628,24 +623,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def eval(input: InternalRow): Any = { val string = str.eval(input) - val po = pos.eval(input) - val ln = len.eval(input) - - if ((string == null) || (po == null) || (ln == null)) { - null - } else { - val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - string match { - case ba: Array[Byte] => - val (st, end) = slicePos(start, length, () => ba.length) - ba.slice(st, end) - case s: UTF8String => + if (string != null) { + val po = pos.eval(input) + if (po != null) { + val ln = len.eval(input) + if (ln != null) { + val start = po.asInstanceOf[Int] + val length = ln.asInstanceOf[Int] + val s = string.asInstanceOf[UTF8String] val (st, end) = slicePos(start, length, () => s.numChars()) s.substring(st, end) + } else { + null + } + } else { + null } + } else { + null } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val strGen = str.gen(ctx) + val posGen = pos.gen(ctx) + val lenGen = len.gen(ctx) + + val start = ctx.freshName("start") + val end = ctx.freshName("end") + + s""" + ${strGen.code} + boolean ${ev.isNull} = ${strGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${posGen.code} + if (!${posGen.isNull}) { + ${lenGen.code} + if (!${lenGen.isNull}) { + int $start = (${posGen.primitive} > 0) ? ${posGen.primitive} - 1 : + ((${posGen.primitive} < 0) ? ${strGen.primitive}.numChars() + ${posGen.primitive} : + 0); + int $end = (${lenGen.primitive} == Integer.MAX_VALUE) ? Integer.MAX_VALUE : + $start + ${lenGen.primitive}; + ${ev.primitive} = ${strGen.primitive}.substring($start, $end); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } } /** From 18c3576446f4807dba92e388b4ec4c62601bdb04 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 16:31:11 -0700 Subject: [PATCH 2/5] [SPARK-9157][SQL] remove slice pos --- .../expressions/stringOperations.scala | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6b752a110a520..de21c7aa60e2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -599,40 +599,31 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def children: Seq[Expression] = str :: pos :: len :: Nil - @inline - def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { + override def eval(input: InternalRow): Any = { + // Information regarding the pos calculation: // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. - - val start = startPos match { - case pos if pos > 0 => pos - 1 - case neg if neg < 0 => length() + neg - case _ => 0 - } - - val end = sliceLen match { - case max if max == Integer.MAX_VALUE => max - case x => start + x - } - - (start, end) - } - - override def eval(input: InternalRow): Any = { val string = str.eval(input) if (string != null) { val po = pos.eval(input) if (po != null) { val ln = len.eval(input) if (ln != null) { - val start = po.asInstanceOf[Int] val length = ln.asInstanceOf[Int] val s = string.asInstanceOf[UTF8String] - val (st, end) = slicePos(start, length, () => s.numChars()) - s.substring(st, end) + val pos = po.asInstanceOf[Int] + val start = { + if (pos > 0) { + pos - 1 + } else { + if (pos < 0) s.numChars() + pos else 0 + } + } + val end = if (length == Integer.MAX_VALUE) Integer.MAX_VALUE else start + length + s.substring(start, end) } else { null } From 60732ea37e109fcba2f2f213712a9b7ed31e5e45 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 18:43:01 -0700 Subject: [PATCH 3/5] [SPARK-9157] created substringSQL in UTF8String --- .../expressions/stringOperations.scala | 34 ++++---------- .../apache/spark/unsafe/types/UTF8String.java | 12 +++++ .../spark/unsafe/types/UTF8StringSuite.java | 47 +++++++++++++------ 3 files changed, 55 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index de21c7aa60e2a..4fcf43834568c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -606,24 +606,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. - val string = str.eval(input) - if (string != null) { - val po = pos.eval(input) - if (po != null) { - val ln = len.eval(input) - if (ln != null) { - val length = ln.asInstanceOf[Int] - val s = string.asInstanceOf[UTF8String] - val pos = po.asInstanceOf[Int] - val start = { - if (pos > 0) { - pos - 1 - } else { - if (pos < 0) s.numChars() + pos else 0 - } - } - val end = if (length == Integer.MAX_VALUE) Integer.MAX_VALUE else start + length - s.substring(start, end) + val stringEval = str.eval(input) + if (stringEval != null) { + val posEval = pos.eval(input) + if (posEval != null) { + val lenEval = len.eval(input) + if (lenEval != null) { + stringEval.asInstanceOf[UTF8String] + .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) } else { null } @@ -652,12 +642,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (!${posGen.isNull}) { ${lenGen.code} if (!${lenGen.isNull}) { - int $start = (${posGen.primitive} > 0) ? ${posGen.primitive} - 1 : - ((${posGen.primitive} < 0) ? ${strGen.primitive}.numChars() + ${posGen.primitive} : - 0); - int $end = (${lenGen.primitive} == Integer.MAX_VALUE) ? Integer.MAX_VALUE : - $start + ${lenGen.primitive}; - ${ev.primitive} = ${strGen.primitive}.substring($start, $end); + ${ev.primitive} = ${strGen.primitive} + .substringSQL(${posGen.primitive}, ${lenGen.primitive}); } else { ${ev.isNull} = true; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3eecd657e6ef9..a579e6ddbbba8 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -153,6 +153,18 @@ public UTF8String substring(final int start, final int until) { return fromBytes(bytes); } + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0); + int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length; + return substring(start, end); + } + /** * Returns whether this contains `substring` or not. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7d0c49e2fb84c..db5e93dc5c104 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -112,26 +112,26 @@ public void concatWsTest() { fromString(""), concatWs(sep, fromString(""))); assertEquals( - fromString("ab"), - concatWs(sep, fromString("ab"))); + fromString("ab"), + concatWs(sep, fromString("ab"))); assertEquals( - fromString("a哈哈b"), - concatWs(sep, fromString("a"), fromString("b"))); + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); assertEquals( - fromString("a哈哈b哈哈c"), - concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); assertEquals( - fromString("a哈哈c"), - concatWs(sep, fromString("a"), null, fromString("c"))); + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); assertEquals( - fromString("a"), - concatWs(sep, fromString("a"), null, null)); + fromString("a"), + concatWs(sep, fromString("a"), null, null)); assertEquals( - fromString(""), - concatWs(sep, null, null, null)); + fromString(""), + concatWs(sep, null, null, null)); assertEquals( - fromString("数据哈哈砖头"), - concatWs(sep, fromString("数据"), fromString("砖头"))); + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); } @Test @@ -262,6 +262,25 @@ public void pad() { fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } + + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), fromString("")); + assertEquals(e.substringSQL(100, 4), fromString("")); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } @Test public void levenshteinDistance() { From 44e89f82e0b70f5a4bc830a2537dd5976ce78b34 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 18:45:24 -0700 Subject: [PATCH 4/5] [SPARK-9157] use EMPTY_UTF8 --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 37733e7e53e09..853497e5c32fd 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -285,7 +285,7 @@ public void substringSQL() { assertEquals(e.substringSQL(1, 6), fromString("exampl")); assertEquals(e.substringSQL(2, 100), fromString("xample")); assertEquals(e.substringSQL(0, 0), fromString("")); - assertEquals(e.substringSQL(100, 4), fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); From e65e3e977ed1c22322cf7d79d2e4bb72a36b3e7f Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 18:46:45 -0700 Subject: [PATCH 5/5] [SPARK-9157] indent fix --- .../spark/unsafe/types/UTF8StringSuite.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 853497e5c32fd..e2a5628ff4d93 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -121,26 +121,26 @@ public void concatWsTest() { EMPTY_UTF8, concatWs(sep, EMPTY_UTF8)); assertEquals( - fromString("ab"), - concatWs(sep, fromString("ab"))); + fromString("ab"), + concatWs(sep, fromString("ab"))); assertEquals( - fromString("a哈哈b"), - concatWs(sep, fromString("a"), fromString("b"))); + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); assertEquals( - fromString("a哈哈b哈哈c"), - concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); assertEquals( - fromString("a哈哈c"), - concatWs(sep, fromString("a"), null, fromString("c"))); + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); assertEquals( - fromString("a"), - concatWs(sep, fromString("a"), null, null)); + fromString("a"), + concatWs(sep, fromString("a"), null, null)); assertEquals( EMPTY_UTF8, concatWs(sep, null, null, null)); assertEquals( - fromString("数据哈哈砖头"), - concatWs(sep, fromString("数据"), fromString("砖头"))); + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); } @Test