[SPARK-18777][PYTHON][SQL] Return UDF from udf.register
## What changes were proposed in this pull request?

- Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`.
- Return wrapped udf from `catalog.registerFunction` and dependent methods.
- Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`.
- Unit tests (see the usage sketch below).
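
A minimal sketch of the resulting behavior, assuming an active `SparkSession` named `spark` (the `squared` UDF is invented for illustration): registration now hands back a callable that works on the DataFrame side as well as by name in SQL.

```python
from pyspark.sql.types import IntegerType

# register() now returns the wrapped UDF instead of None, so a single
# call covers both the SQL-visible name and a DataFrame-side callable.
squared = spark.udf.register("squared", lambda x: x * x, IntegerType())

spark.sql("SELECT squared(3)").show()        # callable by name in SQL
spark.range(5).select(squared("id")).show()  # and as a column expression
```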

## How was this patch tested?

- Existing unit tests and doctests.
- Additional tests covering new feature.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17831 from zero323/SPARK-18777.
zero323 authored and gatorsmile committed May 7, 2017
1 parent cafca54 commit 63d90e7
Showing 4 changed files with 39 additions and 16 deletions.
python/pyspark/sql/catalog.py — 8 additions, 3 deletions
```diff
@@ -237,23 +237,28 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return: a wrapped :class:`UserDefinedFunction`

-        >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]

+        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]

         >>> from pyspark.sql.types import IntegerType
-        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
         udf = UserDefinedFunction(f, returnType, name)
         self._jsparkSession.udf().registerPython(name, udf._judf)
+        return udf._wrapped()

     @since(2.0)
     def isCached(self, tableName):
```
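
As a hedged illustration of what the returned object carries (assuming a live `spark` session; the `to_upper` name is made up for this example), `_wrapped` copies the original function's metadata via `functools.wraps` and pins `func` and `returnType` onto the wrapper:

```python
to_upper = spark.catalog.registerFunction("to_upper", lambda s: s.upper())

to_upper.func("spark")  # the original Python callable: returns 'SPARK'
to_upper.returnType     # StringType, the default return type
to_upper.__name__       # '<lambda>', preserved by functools.wraps
```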
python/pyspark/sql/context.py — 8 additions, 4 deletions
```diff
@@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return: a wrapped :class:`UserDefinedFunction`

-        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]

+        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]

         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        self.sparkSession.catalog.registerFunction(name, f, returnType)
+        return self.sparkSession.catalog.registerFunction(name, f, returnType)

     @ignore_unicode_prefix
     @since(2.1)
```
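
Since `SQLContext.registerFunction` now returns the delegated result, the wrapper can be captured from either entry point. A small sketch, assuming the doctest-style `sqlContext` handle (the `plus_one` name is illustrative):

```python
from pyspark.sql.types import IntegerType

plus_one = sqlContext.registerFunction("plus_one", lambda x: x + 1, IntegerType())
same = sqlContext.udf.register("plus_one", lambda x: x + 1, IntegerType())

# Both paths yield a wrapped UDF carrying the same attached metadata.
assert plus_one.returnType == same.returnType == IntegerType()
```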
python/pyspark/sql/functions.py — 14 additions, 9 deletions
```diff
@@ -1917,6 +1917,19 @@ def __call__(self, *cols):
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

+    def _wrapped(self):
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+        @functools.wraps(self.func)
+        def wrapper(*args):
+            return self(*args)
+
+        wrapper.func = self.func
+        wrapper.returnType = self.returnType
+
+        return wrapper
+

 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()):
     """
     def _udf(f, returnType=StringType()):
         udf_obj = UserDefinedFunction(f, returnType)
-
-        @functools.wraps(f)
-        def wrapper(*args):
-            return udf_obj(*args)
-
-        wrapper.func = udf_obj.func
-        wrapper.returnType = udf_obj.returnType
-
-        return wrapper
+        return udf_obj._wrapped()

     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
```
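
To see why both `udf` and `registerFunction` can share `_wrapped`, here is the pattern in isolation as a self-contained sketch (the `FakeUDF` class is invented for illustration, not part of the patch): `functools.wraps` lifts the wrapped function's `__name__` and `__doc__` onto the wrapper, so `help()` and doctests see the user's function rather than the callable object.

```python
import functools


class FakeUDF(object):
    """Stand-in for UserDefinedFunction: a callable object around a function."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args):
        return self.func(*args)

    def _wrapped(self):
        # Same trick as the patch: expose a plain function whose metadata
        # (__name__, __doc__) is copied from the user's original function.
        @functools.wraps(self.func)
        def wrapper(*args):
            return self(*args)

        wrapper.func = self.func
        return wrapper


def double(x):
    """Return twice the input."""
    return 2 * x


wrapped = FakeUDF(double)._wrapped()
print(wrapped.__name__, "-", wrapped.__doc__)  # double - Return twice the input.
print(wrapped(21))                             # 42, dispatched through FakeUDF.__call__
```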
python/pyspark/sql/tests.py — 9 additions, 0 deletions
```diff
@@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])

+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect()
+        )
+
     def test_wholefile_json(self):
         people1 = self.spark.read.json("python/test_support/sql/people.json")
         people_array = self.spark.read.json("python/test_support/sql/people_array.json",
```
