diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3ee485cbd95da..03266138024cd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -30,15 +30,22 @@
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
+    _create_column_from_name
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_udf
 
+# Note to developers: all PySpark functions here take string as column names whenever possible.
+# Namely, when columns are referred to as arguments, they can always be either Column or string,
+# even though there might be a few exceptions for legacy or inevitable reasons.
+# If you are fixing other language APIs together, please also note that this is not the case on
+# the Scala side, since it would require an overridden definition for every single function.
 
-def _create_name_function(name, doc=""):
-    """ Create a function that takes a column name argument, by name"""
+
+def _create_function(name, doc=""):
+    """Create a PySpark function by its name"""
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
@@ -48,8 +55,11 @@ def _(col):
     return _
 
-def _create_function(name, doc=""):
-    """ Create a function that takes a Column object, by name"""
+def _create_function_over_column(name, doc=""):
+    """Similar to `_create_function` but creates a PySpark function that takes a column
+    (as a string as well). This is mainly for PySpark functions to take strings as
+    column names.
+    """
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
@@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""):
     """ Create a binary mathfunction by name"""
     def _(col1, col2):
         sc = SparkContext._active_spark_context
-        # users might write ints for simplicity. This would throw an error on the JVM side.
-        jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
-                                              col2._jc if isinstance(col2, Column) else float(col2))
+        # For legacy reasons, the arguments here can be implicitly converted into floats
+        # if they are not columns or strings.
+        if isinstance(col1, Column):
+            arg1 = col1._jc
+        elif isinstance(col1, basestring):
+            arg1 = _create_column_from_name(col1)
+        else:
+            arg1 = float(col1)
+
+        if isinstance(col2, Column):
+            arg2 = col2._jc
+        elif isinstance(col2, basestring):
+            arg2 = _create_column_from_name(col2)
+        else:
+            arg2 = float(col2)
+
+        jc = getattr(sc._jvm.functions, name)(arg1, arg2)
         return Column(jc)
     _.__name__ = name
     _.__doc__ = doc
@@ -96,8 +120,7 @@ def _():
 >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
 [Row(height=5, spark_user=True)]
 """
-_name_functions = {
-    # name functions take a column name as their argument
+_functions = {
     'lit': _lit_doc,
     'col': 'Returns a :class:`Column` based on the given column name.',
     'column': 'Returns a :class:`Column` based on the given column name.',
@@ -105,9 +128,7 @@ def _():
     'desc': 'Returns a sort expression based on the descending order of the given column name.',
 }
 
-_functions = {
-    'upper': 'Converts a string expression to upper case.',
-    'lower': 'Converts a string expression to upper case.',
+_functions_over_column = {
     'sqrt': 'Computes the square root of the specified float value.',
     'abs': 'Computes the absolute value.',
@@ -120,7 +141,7 @@ def _():
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
-_functions_1_4 = {
+_functions_1_4_over_column = {
     # unary math functions
     'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
     'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
@@ -155,7 +176,7 @@ def _():
     'bitwiseNOT': 'Computes bitwise not.',
 }
 
-_name_functions_2_4 = {
+_functions_2_4 = {
     'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
                        ' column name, and null values return before non-null values.',
     'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
@@ -186,7 +207,7 @@ def _():
 >>> df2.agg(collect_set('age')).collect()
 [Row(collect_set(age)=[5, 2])]
 """
-_functions_1_6 = {
+_functions_1_6_over_column = {
     # unary math functions
     'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
               ' the expression in a group.',
@@ -203,7 +224,7 @@ def _():
     'collect_set': _collect_set_doc
 }
 
-_functions_2_1 = {
+_functions_2_1_over_column = {
     # unary math functions
     'degrees': """
                Converts an angle measured in radians to an approximately equivalent angle
@@ -268,24 +289,24 @@ def _():
 _functions_deprecated = {
 }
 
-for _name, _doc in _name_functions.items():
-    globals()[_name] = since(1.3)(_create_name_function(_name, _doc))
 for _name, _doc in _functions.items():
     globals()[_name] = since(1.3)(_create_function(_name, _doc))
-for _name, _doc in _functions_1_4.items():
-    globals()[_name] = since(1.4)(_create_function(_name, _doc))
+for _name, _doc in _functions_over_column.items():
+    globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_1_4_over_column.items():
+    globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
 for _name, _doc in _binary_mathfunctions.items():
     globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
 for _name, _doc in _window_functions.items():
     globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
-for _name, _doc in _functions_1_6.items():
-    globals()[_name] = since(1.6)(_create_function(_name, _doc))
-for _name, _doc in _functions_2_1.items():
-    globals()[_name] = since(2.1)(_create_function(_name, _doc))
+for _name, _doc in _functions_1_6_over_column.items():
+    globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_2_1_over_column.items():
+    globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
 for _name, _message in _functions_deprecated.items():
     globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
-for _name, _doc in _name_functions_2_4.items():
-    globals()[_name] = since(2.4)(_create_name_function(_name, _doc))
+for _name, _doc in _functions_2_4.items():
+    globals()[_name] = since(2.4)(_create_function(_name, _doc))
 del _name, _doc
@@ -1450,6 +1471,8 @@ def hash(*cols):
 # ---------------------- String/Binary functions ------------------------------
 
 _string_functions = {
+    'upper': 'Converts a string expression to upper case.',
+    'lower': 'Converts a string expression to lower case.',
     'ascii': 'Computes the numeric value of the first character of the string column.',
     'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
     'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
@@ -1460,7 +1483,7 @@ def hash(*cols):
 
 for _name, _doc in _string_functions.items():
-    globals()[_name] = since(1.5)(_create_function(_name, _doc))
+    globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
 del _name, _doc
 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index fe6660272e323..b77757342843d 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -129,6 +129,12 @@ def assert_close(a, b):
                      df.select(functions.pow(df.a, 2.0)).collect())
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())
+        assert_close([math.hypot(i, 2 * i) for i in range(10)],
+                     df.select(functions.hypot("a", u"b")).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot("a", 2)).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot(df.a, 2)).collect())
 
     def test_rand_functions(self):
         df = self.df
@@ -151,7 +157,8 @@ def test_rand_functions(self):
         self.assertEqual(sorted(rndn1), sorted(rndn2))
 
     def test_string_functions(self):
-        from pyspark.sql.functions import col, lit
+        from pyspark.sql import functions
+        from pyspark.sql.functions import col, lit, _string_functions
         df = self.spark.createDataFrame([['nick']], schema=['name'])
         self.assertRaisesRegexp(
             TypeError,
@@ -162,6 +169,11 @@ def test_string_functions(self):
             TypeError,
             lambda: df.select(col('name').substr(long(0), long(1))))
 
+        for name in _string_functions.keys():
+            self.assertEqual(
+                df.select(getattr(functions, name)("name")).first()[0],
+                df.select(getattr(functions, name)(col("name"))).first()[0])
+
     def test_array_contains_function(self):
         from pyspark.sql.functions import array_contains
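For reviewers, a brief usage sketch of the behavior this patch enables (not part of the diff; the local SparkSession and the sample DataFrame below are illustrative assumptions, not code from the patch):

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.functions import col

    # Hypothetical local session and sample data, for demonstration only.
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([("nick", 3.0, 4.0)], ["name", "a", "b"])

    # String column names are now accepted where only Column objects worked before.
    df.select(F.lower("name")).show()       # equivalent to F.lower(col("name"))
    df.select(F.sqrt("a")).show()           # equivalent to F.sqrt(col("a"))

    # Binary math functions accept any mix of Column, column name, and numeric literal.
    df.select(F.hypot("a", "b")).show()     # both arguments as column names
    df.select(F.hypot(col("a"), 2)).show()  # Column plus a numeric literal (coerced to float)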