From 52274f73a69d37151876d49e0f709ad3b45ba8c3 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 14 Jul 2015 23:51:55 -0700 Subject: [PATCH 1/4] add support for udf_format_number and length for binary --- .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../expressions/stringOperations.scala | 92 ++++++++++++++++-- .../expressions/StringFunctionsSuite.scala | 54 +++++++++-- .../org/apache/spark/sql/functions.scala | 30 +++++- .../spark/sql/DataFrameFunctionsSuite.scala | 93 ++++++++++++++++--- 5 files changed, 238 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d2678ce860701..e0beafe710079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -152,11 +152,12 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 03b55ce5fe7cc..dde8186ad5d73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,10 @@ package 
org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale import java.util.regex.Pattern -import org.apache.commons.lang3.StringUtils - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -553,17 +552,23 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + case NullType => defineCodeGen(ctx, ev, c => s"-1") + } } override def prettyName: String = "length" @@ -668,3 +673,74 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. 
+ */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def foldable: Boolean = x.foldable && d.foldable + override def nullable: Boolean = x.nullable || d.nullable + + @transient + private var lastDValue: Int = -100 + + @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override def eval(input: InternalRow): Any = { + val xObject = x.eval(input) + if (xObject == null) { + return null + } + + val dObject = d.eval(input) + + if (dObject == null || dObject.asInstanceOf[Int] < 0) { + throw new IllegalArgumentException( + s"Argument 2 of function FORMAT_NUMBER must be >= 0, but $dObject was found") + } + val dValue = dObject.asInstanceOf[Int] + + if (dValue != lastDValue) { + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length()) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString()) + lastDValue = dValue; + numberFormat.applyPattern(dFormat.toPattern()) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => 
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index b19f4ee37a109..ebe25bcd4f21a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -426,4 +417,47 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters 
are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + + checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c7deaca8437a1..ab9301f850be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1685,20 +1685,42 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Computes the length of a given string / binary value * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string column. + * Computes the length of a given string / binary column * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def length(columnName: String): Column = length(Column(columnName)) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + * + * @group string_funcs + * @since 1.5.0 + */ + def formatNumber(x: Column, d: Column): Column = FormatNumber(x.expr, d.expr) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + * + * @group string_funcs + * @since 1.5.0 + */ + def formatNumber(columnXName: String, columnDName: String): Column = { + formatNumber(Column(columnXName), Column(columnDName)) + } /** * Computes the Levenshtein distance of the two given strings. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70bd78737f69c..dad1232888a91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - test("Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) @@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest { val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( doubleData.select(pmod('a, 'b)), - Seq(Row(3.1000000000000005)) // same as hive + Seq(Row(3.1000000000000005)) // same as hive ) checkAnswer( doubleData.select(pmod(lit(2), lit(Int.MaxValue))), Seq(Row(2)) ) } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 
7.128381 + + checkAnswer( + df.select( + formatNumber($"f", $"e"), + formatNumber("f", "e")), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // double type of the 2nd argument is unacceptable + Row("5.0000")) + } + } } From 3ebe288e6f659deb9b4a1de7cc11ea1fc7f5db2b Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 00:37:13 -0700 Subject: [PATCH 2/4] update as feedback --- .../expressions/stringOperations.scala | 8 +++---- .../expressions/StringFunctionsSuite.scala | 3 +-- .../org/apache/spark/sql/functions.scala | 24 ++++++++++--------- .../spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index dde8186ad5d73..37e0206227a94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -567,7 +567,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") - case NullType => defineCodeGen(ctx, ev, c => s"-1") } } @@ -685,8 +684,6 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) - override def foldable: Boolean = x.foldable && d.foldable - override def nullable: Boolean = x.nullable || d.nullable @transient private var lastDValue: Int = -100 @@ -706,8 +703,7 @@ case class FormatNumber(x: Expression, d: Expression) val dObject = d.eval(input) if (dObject == null || dObject.asInstanceOf[Int] < 0) { - throw new IllegalArgumentException( - s"Argument 2 of function FORMAT_NUMBER must be >= 0, but $dObject was found") + return null } val dValue = dObject.asInstanceOf[Int] @@ -742,5 +738,7 @@ case class FormatNumber(x: Expression, d: Expression) UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) } } + + override def prettyName: String = "format_number" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index ebe25bcd4f21a..5d7763bedf6bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -441,8 +441,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) 
checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) - - checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null)) } test("number format") { @@ -453,6 +451,7 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) checkEvaluation( FormatNumber( Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ab9301f850be5..d6da284a4c788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1685,7 +1685,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string / binary value + * Computes the length of a given string / binary value. * * @group string_funcs * @since 1.5.0 @@ -1693,7 +1693,7 @@ object functions { def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string / binary column + * Computes the length of a given string / binary column. * * @group string_funcs * @since 1.5.0 @@ -1701,25 +1701,27 @@ object functions { def length(columnName: String): Column = length(Column(columnName)) /** - * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, - * and returns the result as a string. If D is 0, the result has no decimal point or - * fractional part. 
+ * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. * * @group string_funcs * @since 1.5.0 */ - def formatNumber(x: Column, d: Column): Column = FormatNumber(x.expr, d.expr) + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) /** - * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, - * and returns the result as a string. If D is 0, the result has no decimal point or - * fractional part. + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. * * @group string_funcs * @since 1.5.0 */ - def formatNumber(columnXName: String, columnDName: String): Column = { - formatNumber(Column(columnXName), Column(columnDName)) + def format_number(columnXName: String, d: Int): Column = { + format_number(Column(columnXName), d) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index dad1232888a91..6dccdd857b453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -465,8 +465,8 @@ class DataFrameFunctionsSuite extends QueryTest { checkAnswer( df.select( - formatNumber($"f", $"e"), - formatNumber("f", "e")), + format_number($"f", 4), + format_number("f", 4)), Row("5.0000", "5.0000")) checkAnswer( From 601bbf550d0d142cb766a3911ba253bce4099408 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 18:22:39 -0700 Subject: [PATCH 3/4] add python API support --- python/pyspark/sql/functions.py | 24 +++++++++++++++---- 
.../expressions/stringOperations.scala | 4 ++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dca39fa833435..8857ade058208 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,6 +39,8 @@ 'coalesce', 'countDistinct', 'explode', + 'format_number', + 'length', 'log2', 'md5', 'monotonicallyIncreasingId', @@ -47,7 +49,6 @@ 'sha1', 'sha2', 'sparkPartitionId', - 'strlen', 'struct', 'udf', 'when'] @@ -506,14 +507,27 @@ def sparkPartitionId(): @ignore_unicode_prefix @since(1.5) -def strlen(col): - """Calculates the length of a string expression. +def length(col): + """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.strlen(_to_java_column(col))) + return Column(sc._jvm.functions.length(_to_java_column(col))) + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + and returns the result as a string. 
+ :param col: the column name of the numeric value to be formatted + :param d: the number of decimal places + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 37e0206227a94..c64afe7b3f19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -685,9 +685,13 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + // The d value the current pattern was built for; the pattern (DecimalFormat) is + // updated once an incoming d value differs from the last one. @transient private var lastDValue: Int = -100 + // Pattern buffer and cached DecimalFormat, kept for performance; they are + // rebuilt only when the d value changes.
@transient private val pattern: StringBuffer = new StringBuffer() From e534b87a125d264123216025d16d61da327f837d Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 19:40:20 -0700 Subject: [PATCH 4/4] python api style issue --- python/pyspark/sql/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8857ade058208..e0816b3e654bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -516,6 +516,7 @@ def length(col): sc = SparkContext._active_spark_context return Column(sc._jvm.functions.length(_to_java_column(col))) + @ignore_unicode_prefix @since(1.5) def format_number(col, d):