Skip to content

Commit

Permalink
format_number udf should take user-specified format as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Apr 9, 2018
1 parent b02e76c commit 202fa3d
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,8 @@ case class Encode(value: Expression, charset: Expression)
Examples:
> SELECT _FUNC_(12332.123456, 4);
12,332.1235
> SELECT _FUNC_(12332.123456, '##################.###');
12332.123
""")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
Expand All @@ -2030,14 +2032,20 @@ case class FormatNumber(x: Expression, d: Expression)
override def right: Expression = d
override def dataType: DataType = StringType
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(NumericType, TypeCollection(IntegerType, StringType))

private val defaultFormat = "#,###,###,###,###,###,##0"

// The last d value that the pattern was built for. The pattern (DecimalFormat) is
// rebuilt whenever an incoming d value differs from this one.
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
// serialization (numberFormat has not been updated for dValue = 0).
@transient
private var lastDValue: Option[Int] = None
private var lastDIntValue: Option[Int] = None

@transient
private var lastDStringValue: Option[String] = None

// A cached DecimalFormat. For performance reasons, its pattern is changed
// only when the d value changes.
Expand All @@ -2050,33 +2058,49 @@ case class FormatNumber(x: Expression, d: Expression)
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))

override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append("#,###,###,###,###,###,##0")

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
right.dataType match {
case IntegerType =>
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue = Some(dValue)
lastDIntValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append(defaultFormat)

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
}

lastDIntValue = Some(dValue)

numberFormat.applyLocalizedPattern(pattern.toString)
numberFormat.applyLocalizedPattern(pattern.toString)
}
case StringType =>
val dValue = dObject.asInstanceOf[UTF8String].toString
lastDStringValue match {
case Some(last) if last == dValue =>
case _ =>
pattern.delete(0, pattern.length)
lastDStringValue = Some(dValue)
if (dValue.toString.isEmpty) {
numberFormat.applyLocalizedPattern(defaultFormat)
} else {
numberFormat.applyLocalizedPattern(dValue)
}
}
}

x.dataType match {
Expand Down Expand Up @@ -2108,35 +2132,52 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
val lastDValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")

s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDValue) {
$pattern.append("#,###,###,###,###,###,##0");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
right.dataType match {
case IntegerType =>
val i = ctx.freshName("i")
val lastDIntValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDIntValue) {
$pattern.append("$defaultFormat");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
}
}
$lastDIntValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
$lastDValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
"""
"""
case StringType =>
val lastDStringValue =
ctx.addMutableState("String", "lastDValue", v => s"""$v = "$defaultFormat";""")
s"""
if (!$d.toString().equals($lastDStringValue)) {
$lastDStringValue = $d.toString();
if ($d.toString().isEmpty()) {
$numberFormat.applyLocalizedPattern("$defaultFormat");
} else {
$numberFormat.applyLocalizedPattern($d.toString());
}
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
"""
}
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,23 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)

checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(12831273.23481d),
Literal("###,###,###,###,###.###")), "12,831,273.235")
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
"123,123,324,123")
checkEvaluation(
FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
}

test("find in set") {
Expand Down

0 comments on commit 202fa3d

Please sign in to comment.