75 changes: 54 additions & 21 deletions python/pyspark/sql/catalog.py
@@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName):

@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statement.
as a UDF. The registered UDF can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it default to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
Review comment (Member): Should this be ":func:`spark.catalog.registerFunction` is an alias for :func:`spark.udf.register`."?


:param name: name of the UDF
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to string type and the produced
object must match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses
the return type of the given UDF as the return type of the registered UDF. `returnType`
defaults to None and, if given by users, must be None.

:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`

>>> strlen = spark.catalog.registerFunction("stringLengthString", len)
>>> spark.sql("SELECT stringLengthString('test')").collect()
@@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = spark.udf.register("slen", slen)
>>> spark.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
>>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
[Row(random_udf()=82)]
>>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
Review comment (Member): Is there a reason to return to an underscore placeholder? It might seem confusing to users if not required.

Reply (Member Author): This is to avoid the random hex value returned by PySpark from appearing in the output. You can try spark.udf.register("add_one", add_one).

With the underscore placeholder, we can remove `# doctest: +SKIP`.

>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""

# This is to check whether the input function is a wrapped/native UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF,
deterministic=f.deterministic)
if returnType is not None:
raise TypeError(
"Invalid returnType: None is expected when f is a UserDefinedFunction, "
Review comment (Member): Here too, I think we should say that returnType is disallowed to be set when f is a UserDefinedFunction.

"but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
return_udf = f
else:
udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
if returnType is None:
returnType = StringType()
register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
return_udf = register_udf._wrapped()
self._jsparkSession.udf().registerPython(name, register_udf._judf)
return return_udf

@since(2.0)
def isCached(self, tableName):
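For reference, a minimal standalone sketch of the registration contract the catalog.py change above introduces: a plain Python function gets a string return type by default, a wrapped/native UserDefinedFunction keeps its own return type, and passing a returnType together with a UserDefinedFunction raises TypeError. This assumes a local SparkSession and Spark 2.3-era APIs; the UDF names are illustrative only.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("register-sketch").getOrCreate()

# 1) Plain Python function: returnType defaults to string, so SQL returns the string '4'.
spark.catalog.registerFunction("strlen_str", lambda s: len(s))
spark.sql("SELECT strlen_str('test')").show()

# 2) Wrapped UDF: the registration keeps the UDF's own return type (integer here),
#    so returnType must be left unset (None).
strlen_int = udf(lambda s: len(s), IntegerType())
spark.catalog.registerFunction("strlen_int", strlen_int)
spark.sql("SELECT strlen_int('test')").show()

# 3) Supplying both a wrapped UDF and a returnType raises TypeError under this change.
try:
    spark.catalog.registerFunction("bad_register", strlen_int, IntegerType())
except TypeError as e:
    print(e)  # Invalid returnType: None is expected when f is a UserDefinedFunction, ...

spark.stop()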
51 changes: 36 additions & 15 deletions python/pyspark/sql/context.py
@@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None):

@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statement.
as a UDF. The registered UDF can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it default to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
Review comment (Member): Should this be ":func:`sqlContext.registerFunction` is an alias for :func:`spark.udf.register`."?


:param name: name of the UDF
:param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to string type and the produced
object must match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses
the return type of the given UDF as the return type of the registered UDF. `returnType`
defaults to None and, if given by users, must be None.
Review comment (Member): I think we should simply say that setting a data type for returnType is disallowed, rather than that None should be set.


:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`

>>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
@@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]

>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = sqlContext.udf.register("slen", slen)
>>> sqlContext.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]

>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType, StringType
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
>>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=u'82')]
>>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
[Row(random_udf()=u'62')]
[Row(random_udf()=82)]
>>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP
>>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)

@@ -575,7 +596,7 @@ class UDFRegistration(object):
def __init__(self, sqlContext):
self.sqlContext = sqlContext

def register(self, name, f, returnType=StringType()):
def register(self, name, f, returnType=None):
return self.sqlContext.registerFunction(name, f, returnType)

def registerJavaFunction(self, name, javaClassName, returnType=None):
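The SQLContext path shown above delegates to the session catalog, so the same rules apply whether a UDF is registered through sqlContext.registerFunction, spark.udf.register, or spark.catalog.registerFunction. A hedged sketch of the vectorized case discussed in the doctests, assuming pandas and pyarrow are installed (required for pandas_udf) and Spark 2.3-era APIs:

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import pandas_udf, PandasUDFType

sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

# A scalar (vectorized) pandas UDF; its "integer" return type travels with it,
# so no returnType is passed at registration time.
@pandas_udf("integer", PandasUDFType.SCALAR)
def add_one(x):
    return x + 1

sqlContext.registerFunction("add_one", add_one)
sqlContext.sql("SELECT add_one(id) FROM range(3)").show()  # 1, 2, 3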
76 changes: 65 additions & 11 deletions python/pyspark/sql/tests.py
@@ -379,12 +379,25 @@ def test_udf2(self):
self.assertEqual(4, res[0])

def test_udf3(self):
twoargs = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
self.assertEqual(twoargs.deterministic, True)
two_args = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
self.assertEqual(two_args.deterministic, True)
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], u'5')

def test_udf_registration_return_type_none(self):
two_args = self.spark.catalog.registerFunction(
"twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
self.assertEqual(two_args.deterministic, True)
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)

def test_udf_registration_return_type_not_none(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
self.spark.catalog.registerFunction(
"f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())

def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
from pyspark.sql.functions import udf
@@ -401,12 +414,12 @@ def test_nondeterministic_udf2(self):
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
self.assertEqual(random_udf1.deterministic, False)
[row] = self.spark.sql("SELECT randInt()").collect()
self.assertEqual(row[0], "6")
self.assertEqual(row[0], 6)
[row] = self.spark.range(1).select(random_udf1()).collect()
self.assertEqual(row[0], "6")
self.assertEqual(row[0], 6)
[row] = self.spark.range(1).select(random_udf()).collect()
self.assertEqual(row[0], 6)
# render_doc() reproduces the help() exception without printing output
@@ -3581,7 +3594,7 @@ def tearDownClass(cls):
ReusedSQLTestCase.tearDownClass()

@property
def random_udf(self):
def nondeterministic_vectorized_udf(self):
from pyspark.sql.functions import pandas_udf

@pandas_udf('double')
@@ -3616,6 +3629,21 @@ def test_vectorized_udf_basic(self):
bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())

def test_register_nondeterministic_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf
from pyspark.rdd import PythonEvalType
import random
random_pandas_udf = pandas_udf(
lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
self.assertEqual(random_pandas_udf.deterministic, False)
self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
"randomPandasUDF", random_pandas_udf)
self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
[row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
self.assertEqual(row[0], 7)

def test_vectorized_udf_null_boolean(self):
from pyspark.sql.functions import pandas_udf, col
data = [(True,), (True,), (None,), (False,)]
@@ -3975,33 +4003,50 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_nondeterministic_udf(self):
def test_nondeterministic_vectorized_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
from pyspark.sql.functions import udf, pandas_udf, col

@pandas_udf('double')
def plus_ten(v):
return v + 10
random_udf = self.random_udf
random_udf = self.nondeterministic_vectorized_udf

df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()

self.assertEqual(random_udf.deterministic, False)
self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))

def test_nondeterministic_udf_in_aggregate(self):
def test_nondeterministic_vectorized_udf_in_aggregate(self):
from pyspark.sql.functions import pandas_udf, sum

df = self.spark.range(10)
random_udf = self.random_udf
random_udf = self.nondeterministic_vectorized_udf

with QuietTest(self.sc):
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
df.agg(sum(random_udf(df.id))).collect()

def test_register_vectorized_udf_basic(self):
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, col, expr
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
col('id').cast('int').alias('b'))
original_add = pandas_udf(lambda x, y: x + y, IntegerType())
self.assertEqual(original_add.deterministic, True)
self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
new_add = self.spark.catalog.registerFunction("add1", original_add)
Review comment (Member): spark.udf.register instead of spark.catalog.registerFunction?

res1 = df.select(new_add(col('a'), col('b')))
res2 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
expected = df.select(expr('a + b'))
self.assertEquals(expected.collect(), res1.collect())
self.assertEquals(expected.collect(), res2.collect())


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedSQLTestCase):
@@ -4037,6 +4082,15 @@ def test_simple(self):
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
self.assertFramesEqual(expected, result)

def test_register_group_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
'SQL_PANDAS_SCALAR_UDF'):
self.spark.catalog.registerFunction("foo_udf", foo_udf)
Review comment (Member): ditto.


def test_decorator(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
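The test_register_group_map_udf case above pins down that group-map pandas UDFs cannot be registered for SQL (only SQL_BATCHED_UDF and SQL_PANDAS_SCALAR_UDF are accepted). For context, a sketch of the supported way to apply a group-map UDF, assuming a SparkSession named spark with pandas and pyarrow installed; the normalize function is illustrative only.

from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

# Each group arrives as a pandas.DataFrame; the returned DataFrame must match the declared schema.
@pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
def normalize(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(normalize).show()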