From 601bbf550d0d142cb766a3911ba253bce4099408 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 15 Jul 2015 18:22:39 -0700
Subject: [PATCH] add Python API support

---
 python/pyspark/sql/functions.py                 | 24 +++++++++++++++----
 .../expressions/stringOperations.scala          |  4 ++++
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index dca39fa833435..8857ade058208 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -39,6 +39,8 @@
     'coalesce',
     'countDistinct',
     'explode',
+    'format_number',
+    'length',
     'log2',
     'md5',
     'monotonicallyIncreasingId',
@@ -47,7 +49,6 @@
     'sha1',
     'sha2',
     'sparkPartitionId',
-    'strlen',
     'struct',
     'udf',
     'when']
@@ -506,14 +507,27 @@ def sparkPartitionId():
 
 @ignore_unicode_prefix
 @since(1.5)
-def strlen(col):
-    """Calculates the length of a string expression.
+def length(col):
+    """Calculates the length of a string or binary expression.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
     [Row(length=3)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.strlen(_to_java_column(col)))
+    return Column(sc._jvm.functions.length(_to_java_column(col)))
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_number(col, d):
+    """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+    and returns the result as a string.
+    :param col: the column name of the numeric value to be formatted
+    :param d: the number of decimal places
+    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+    [Row(v=u'5.0000')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
 
 
 @ignore_unicode_prefix
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 37e0206227a94..c64afe7b3f19a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -685,9 +685,13 @@ case class FormatNumber(x: Expression, d: Expression)
   override def dataType: DataType = StringType
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
 
+  // The last d value seen; the cached pattern (DecimalFormat) is rebuilt
+  // whenever a newly arriving d value differs from this one.
   @transient
   private var lastDValue: Int = -100
 
+  // A cached DecimalFormat pattern, kept for performance; it is rebuilt
+  // only when the d value changes.
   @transient
   private val pattern: StringBuffer = new StringBuffer()
 
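
Reviewer note (not part of the patch): a minimal usage sketch of the two new
Python functions, assuming a Spark 1.5-era SQLContext already bound to
sqlContext, as in the doctests above; the DataFrame, column names, and values
below are illustrative only.

    from pyspark.sql.functions import format_number, length

    # One string column and one numeric column, to exercise both functions.
    df = sqlContext.createDataFrame([('ABC', 5)], ['s', 'n'])

    # length() counts the characters of a string (or bytes of a binary) column;
    # format_number() renders the numeric column with the requested number of
    # decimal places, grouping thousands with commas.
    df.select(length('s').alias('len'),
              format_number('n', 4).alias('formatted')).collect()
    # Expected, based on the doctests in this patch:
    # [Row(len=3, formatted=u'5.0000')]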