Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[SPARK-9154] [SQL] codegen StringFormat" #7570

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according do printf-style format strings
*/
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {

require(children.nonEmpty, "printf() should take at least 1 argument")

Expand All @@ -536,10 +536,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail

override def inputTypes: Seq[AbstractDataType] =
children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)


override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
Expand All @@ -555,42 +551,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val pattern = children.head.gen(ctx)

val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
val argListCode = argListGen.map(_._2.code + "\n")

val argListString = argListGen.foldLeft("")((s, v) => {
val nullSafeString =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
} else {
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
}
s + "," + nullSafeString
})

val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
s"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${argListCode.mkString}
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
$form.format(${pattern.primitive}.toString() $argListString);
${ev.primitive} = UTF8String.fromString($sb.toString());
}
"""
}

override def prettyName: String = "printf"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("FORMAT") {
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
val f = 'f.string.at(0)
val d1 = 'd.int.at(1)
val s1 = 's.int.at(2)

val row1 = create_row("aa%d%s", 12, "cc")
val row2 = create_row(null, 12, "cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)

checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
checkEvaluation(
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
checkEvaluation(StringFormat(f, d1, s1), null, row2)
}

test("INSTR") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,6 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))

val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")

checkAnswer(
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
Row("aa123cc", "aa123cc"))

checkAnswer(
df2.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
}

test("string instr function") {
Expand Down