Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9154][SQL] codegen StringFormat #7546

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according do printf-style format strings
*/
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {

require(children.nonEmpty, "printf() should take at least 1 argument")

Expand All @@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail

override def inputTypes: Seq[AbstractDataType] =
children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marmbrus Is this what you proposed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works. I think StringType :: List.fill(children.size - 1)(AnyDataType) might be a little more clear.



override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
Expand All @@ -501,6 +505,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val pattern = children.head.gen(ctx)

val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
val argListCode = argListGen.map(_._2.code + "\n")

val argListString = argListGen.foldLeft("")((s, v) => {
val nullSafeString =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
} else {
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
}
s + "," + nullSafeString
})

val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
s"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${argListCode.mkString}
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
$form.format(${pattern.primitive}.toString() $argListString);
${ev.primitive} = UTF8String.fromString($sb.toString());
}
"""
}

override def prettyName: String = "printf"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("FORMAT") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existing: would you mind rewriting these to avoid the use of row and just use literals? using a row makes the test cases harder to follow since you have to look in multiple places to understand what is going on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up!

val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")

checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
checkEvaluation(StringFormat(f, d1, s1), null, row2)
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}

test("INSTR") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
df2.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}

test("string instr function") {
Expand Down