Skip to content

Commit

Permalink
add python API support
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Jul 16, 2015
1 parent 3ebe288 commit 601bbf5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
24 changes: 19 additions & 5 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
'coalesce',
'countDistinct',
'explode',
'format_number',
'length',
'log2',
'md5',
'monotonicallyIncreasingId',
Expand All @@ -47,7 +49,6 @@
'sha1',
'sha2',
'sparkPartitionId',
'strlen',
'struct',
'udf',
'when']
Expand Down Expand Up @@ -506,14 +507,27 @@ def sparkPartitionId():

@ignore_unicode_prefix
@since(1.5)
def strlen(col):
"""Calculates the length of a string expression.
def length(col):
"""Calculates the length of a string or binary expression.
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
[Row(length=3)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.strlen(_to_java_column(col)))
return Column(sc._jvm.functions.length(_to_java_column(col)))

@ignore_unicode_prefix
@since(1.5)
def format_number(col, d):
"""Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
and returns the result as a string.
:param col: the column name of the numeric value to be formatted
:param d: the N decimal places
>>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
[Row(v=u'5.0000')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.format_number(_to_java_column(col), d))


@ignore_unicode_prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,9 +685,13 @@ case class FormatNumber(x: Expression, d: Expression)
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
@transient
private var lastDValue: Int = -100

// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
@transient
private val pattern: StringBuffer = new StringBuffer()

Expand Down

0 comments on commit 601bbf5

Please sign in to comment.