From 086caba76f646f86840a2cee325188895ab42c8f Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 09:29:03 -0700 Subject: [PATCH 1/4] [SPARK-9154][SQL] codegen string format --- .../expressions/stringOperations.scala | 24 +++++++++++++++++++ .../expressions/StringExpressionsSuite.scala | 1 + 2 files changed, 25 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5f8ac716f79a1..0ad5d0317db14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -501,6 +501,30 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val pattern = children.head.gen(ctx) + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + val argListString = argListGen.foldLeft("")((s, x) => s + s", ${x._2.primitive}" + (if (!ctx.isPrimitiveType(x._1)) ".toString()" else "")) + val form = ctx.freshName("formatter") + val formatter = classOf[java.util.Formatter].getName + val sb = ctx.freshName("sb") + val stringBuffer = classOf[StringBuffer].getName + + s""" + ${pattern.code} + boolean ${ev.isNull} = ${pattern.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${argListCode.mkString} + $stringBuffer $sb = new $stringBuffer(); + $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $form.format(${pattern.primitive}.toString() $argListString); + ${ev.primitive} = UTF8String.fromString($sb.toString()); + } + """ + } + override def prettyName: String = "printf" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 96f433be8b065..80cf44c7a4104 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -361,6 +361,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + println(StringFormat(f, d1, s1).eval(row1)) checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) checkEvaluation(StringFormat(f, d1, s1), null, row2) } From cd8322bc4e6c15cd9911363c4596eba1a935fcdd Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 14:40:30 -0700 Subject: [PATCH 2/4] [SPARK-9154][SQL] codegen string format --- .../spark/sql/catalyst/expressions/stringOperations.scala | 8 +++++--- .../sql/catalyst/expressions/StringExpressionsSuite.scala | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 0ad5d0317db14..050fe22197ce7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -503,9 +503,11 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val pattern = children.head.gen(ctx) - val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) - val argListCode = argListGen.map(_._2.code + "\n") - val argListString = argListGen.foldLeft("")((s, x) => s + s", ${x._2.primitive}" + (if (!ctx.isPrimitiveType(x._1)) ".toString()" else "")) + + val argListGen = children.tail.map(_.gen(ctx)) + val argListCode = argListGen.map(_.code + "\n") + val argListString = argListGen.foldLeft("")((s, v) => s + s", ${v.primitive}") + val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 80cf44c7a4104..93bb538663cec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -353,7 +353,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("FORMAT") { val f = 'f.string.at(0) val d1 = 'd.int.at(1) - val s1 = 's.int.at(2) + val s1 = 's.string.at(2) val row1 = create_row("aa%d%s", 12, "cc") val row2 = create_row(null, 12, "cc") @@ -361,7 +361,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) - println(StringFormat(f, d1, s1).eval(row1)) checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) checkEvaluation(StringFormat(f, d1, s1), null, row2) } From 10b4de88c817a474b7b0a83d948cb86927638775 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 14:42:28 -0700 Subject: [PATCH 3/4] [SPARK-9154][SQL] codegen removed fallback trait --- .../spark/sql/catalyst/expressions/stringOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 050fe22197ce7..81979ab5d2dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with CodegenFallback { +case class StringFormat(children: Expression*) extends Expression { require(children.nonEmpty, "printf() should take at least 1 argument") From a943d3e60649f4267e40376c0bb1ff30ae024436 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 23:26:58 -0700 Subject: [PATCH 4/4] [SPARK-9154] implicit input cast, added tests for null, support for null primitives --- .../expressions/stringOperations.scala | 24 +++++++++++++++---- .../expressions/StringExpressionsSuite.scala | 18 +++++++------- .../spark/sql/StringFunctionsSuite.scala | 10 ++++++++ 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 81979ab5d2dce..08b17420d6cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression { +case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression { private def format: Expression = children(0) private def args: Seq[Expression] = children.tail + override def inputTypes: Seq[AbstractDataType] = + children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType) + + override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -504,15 +508,25 @@ case class StringFormat(children: Expression*) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val pattern = children.head.gen(ctx) - val argListGen = children.tail.map(_.gen(ctx)) - val argListCode = argListGen.map(_.code + "\n") - val argListString = argListGen.foldLeft("")((s, v) => s + s", ${v.primitive}") + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + + val argListString = argListGen.foldLeft("")((s, v) => { + val nullSafeString = + if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + // Java primitives get boxed in order to allow null values. + s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + + s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + } else { + s"(${v._2.isNull}) ? null : ${v._2.primitive}" + } + s + "," + nullSafeString + }) val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 93bb538663cec..63d09fd6375cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - val f = 'f.string.at(0) - val d1 = 'd.int.at(1) - val s1 = 's.string.at(2) - - val row1 = create_row("aa%d%s", 12, "cc") - val row2 = create_row(null, 12, "cc") - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) - checkEvaluation(StringFormat(f, d1, s1), null, row2) + checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation( + StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + checkEvaluation( + StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index fe4de8d8b855f..274ec8f4675e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) + + val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df2.selectExpr("printf(a, b, c)"), + Row("aa123cc")) } test("string instr function") {