[SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well #24121

Closed · wants to merge 5 commits (changes shown from all commits)
79 changes: 51 additions & 28 deletions python/pyspark/sql/functions.py
@@ -30,15 +30,22 @@

from pyspark import since, SparkContext
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
    _create_column_from_name
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf

# Note to developers: all PySpark functions here take strings as column names whenever possible.
# Namely, when columns are passed as arguments, they can always be either a Column or a string,
# though there may be a few exceptions for legacy or unavoidable reasons.
# If you are fixing the other language APIs together, please note that this does not apply to the
# Scala side, since it would require adding an overloaded definition for every single function.

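As a quick illustration of the convention described in the note above (an editor's sketch, not part of the diff; it assumes a running SparkSession named `spark`):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sqrt

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(4.0,), (9.0,)], ['value'])

# Both spellings are equivalent: a bare string is resolved as a column name.
df.select(sqrt(col('value'))).show()
df.select(sqrt('value')).show()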
def _create_name_function(name, doc=""):
    """ Create a function that takes a column name argument, by name"""

def _create_function(name, doc=""):
    """Create a PySpark function by its name"""
    def _(col):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
@@ -48,8 +55,11 @@ def _(col):
    return _


def _create_function(name, doc=""):
    """ Create a function that takes a Column object, by name"""
def _create_function_over_column(name, doc=""):
    """Similar to `_create_function`, but creates a PySpark function that takes a column
    (or a column name as a string). This is mainly to let PySpark functions accept strings
    as column names.
    """
    def _(col):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
@@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""):
""" Create a binary mathfunction by name"""
def _(col1, col2):
sc = SparkContext._active_spark_context
# users might write ints for simplicity. This would throw an error on the JVM side.
jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
col2._jc if isinstance(col2, Column) else float(col2))
# For legacy reasons, the arguments here can be implicitly converted into floats,
# if they are not columns or strings.
if isinstance(col1, Column):
arg1 = col1._jc
elif isinstance(col1, basestring):
arg1 = _create_column_from_name(col1)
        else:
            arg1 = float(col1)

        if isinstance(col2, Column):
            arg2 = col2._jc
        elif isinstance(col2, basestring):
            arg2 = _create_column_from_name(col2)
        else:
            arg2 = float(col2)

        jc = getattr(sc._jvm.functions, name)(arg1, arg2)
        return Column(jc)
    _.__name__ = name
    _.__doc__ = doc
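From the caller's side, the new coercion logic behaves as sketched below (assuming a DataFrame `df` with numeric columns `a` and `b`; this mirrors the tests added in test_functions.py further down):

from pyspark.sql.functions import hypot

df.select(hypot(df.a, df.b))  # both Columns: passed through as before
df.select(hypot('a', 'b'))    # strings: now resolved as column names
df.select(hypot('a', 2))      # anything else: coerced with float(), here float(2)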
@@ -96,18 +120,15 @@ def _():
    >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
    [Row(height=5, spark_user=True)]
    """
_name_functions = {
# name functions take a column name as their argument
_functions = {
    'lit': _lit_doc,
    'col': 'Returns a :class:`Column` based on the given column name.',
    'column': 'Returns a :class:`Column` based on the given column name.',
    'asc': 'Returns a sort expression based on the ascending order of the given column name.',
    'desc': 'Returns a sort expression based on the descending order of the given column name.',
}

_functions = {
    'upper': 'Converts a string expression to upper case.',
    'lower': 'Converts a string expression to upper case.',
_functions_over_column = {
    'sqrt': 'Computes the square root of the specified float value.',
    'abs': 'Computes the absolute value.',

@@ -120,7 +141,7 @@ def _():
    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

_functions_1_4 = {
_functions_1_4_over_column = {
    # unary math functions
    'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
    'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
@@ -155,7 +176,7 @@ def _():
    'bitwiseNOT': 'Computes bitwise not.',
}

_name_functions_2_4 = {
_functions_2_4 = {
    'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
                       ' column name, and null values return before non-null values.',
    'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
@@ -186,7 +207,7 @@ def _():
    >>> df2.agg(collect_set('age')).collect()
    [Row(collect_set(age)=[5, 2])]
    """
_functions_1_6 = {
_functions_1_6_over_column = {
    # unary math functions
    'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
              ' the expression in a group.',
@@ -203,7 +224,7 @@ def _():
    'collect_set': _collect_set_doc
}

_functions_2_1 = {
_functions_2_1_over_column = {
    # unary math functions
    'degrees': """
               Converts an angle measured in radians to an approximately equivalent angle
@@ -268,24 +289,24 @@ def _():
_functions_deprecated = {
}

for _name, _doc in _name_functions.items():
    globals()[_name] = since(1.3)(_create_name_function(_name, _doc))
for _name, _doc in _functions.items():
    globals()[_name] = since(1.3)(_create_function(_name, _doc))
for _name, _doc in _functions_1_4.items():
    globals()[_name] = since(1.4)(_create_function(_name, _doc))
for _name, _doc in _functions_over_column.items():
    globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
for _name, _doc in _functions_1_4_over_column.items():
    globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
for _name, _doc in _binary_mathfunctions.items():
    globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
for _name, _doc in _window_functions.items():
    globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
for _name, _doc in _functions_1_6.items():
    globals()[_name] = since(1.6)(_create_function(_name, _doc))
for _name, _doc in _functions_2_1.items():
    globals()[_name] = since(2.1)(_create_function(_name, _doc))
for _name, _doc in _functions_1_6_over_column.items():
    globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
for _name, _doc in _functions_2_1_over_column.items():
    globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
for _name, _message in _functions_deprecated.items():
    globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
for _name, _doc in _name_functions_2_4.items():
    globals()[_name] = since(2.4)(_create_name_function(_name, _doc))
for _name, _doc in _functions_2_4.items():
    globals()[_name] = since(2.4)(_create_function(_name, _doc))
del _name, _doc


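For readers unfamiliar with this registration pattern, here is a minimal, self-contained sketch of the same technique the loops above use; the names (`_make_greeter`, `_greetings`) are hypothetical and not part of Spark:

def _make_greeter(name, doc=""):
    """Build a function and register it under `name`, like the factories above."""
    def _(who):
        return '%s, %s!' % (name.capitalize(), who)
    _.__name__ = name
    _.__doc__ = doc
    return _

_greetings = {'hello': 'Greets enthusiastically.', 'goodbye': 'Bids farewell.'}
for _name, _doc in _greetings.items():
    globals()[_name] = _make_greeter(_name, _doc)
del _name, _doc

print(hello('world'))  # prints: Hello, world!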
@@ -1450,6 +1471,8 @@ def hash(*cols):
# ---------------------- String/Binary functions ------------------------------

_string_functions = {
    'upper': 'Converts a string expression to upper case.',
    'lower': 'Converts a string expression to lower case.',
    'ascii': 'Computes the numeric value of the first character of the string column.',
    'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
    'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
@@ -1460,7 +1483,7 @@


for _name, _doc in _string_functions.items():
    globals()[_name] = since(1.5)(_create_function(_name, _doc))
    globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
del _name, _doc


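Since `upper` and `lower` are now registered through `_create_function_over_column`, they accept a plain column name as well as a Column. A short usage sketch (assumes a running SparkSession named `spark`):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, upper, lower

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([['Nick']], schema=['name'])

# Equivalent: Column object vs. column name given as a string.
df.select(upper(col('name'))).show()  # NICK
df.select(lower('name')).show()       # nick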
14 changes: 13 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
@@ -129,6 +129,12 @@ def assert_close(a, b):
                     df.select(functions.pow(df.a, 2.0)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot(df.a, df.b)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot("a", u"b")).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot("a", 2)).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot(df.a, 2)).collect())

    def test_rand_functions(self):
        df = self.df
@@ -151,7 +157,8 @@ def test_rand_functions(self):
        self.assertEqual(sorted(rndn1), sorted(rndn2))

    def test_string_functions(self):
        from pyspark.sql.functions import col, lit
        from pyspark.sql import functions
        from pyspark.sql.functions import col, lit, _string_functions
        df = self.spark.createDataFrame([['nick']], schema=['name'])
        self.assertRaisesRegexp(
            TypeError,
@@ -162,6 +169,11 @@
            TypeError,
            lambda: df.select(col('name').substr(long(0), long(1))))

        for name in _string_functions.keys():
            self.assertEqual(
                df.select(getattr(functions, name)("name")).first()[0],
                df.select(getattr(functions, name)(col("name"))).first()[0])

    def test_array_contains_function(self):
        from pyspark.sql.functions import array_contains
