[SPARK-22978] [PySpark] Register Vectorized UDFs for SQL Statement #20171
Changes from all commits: 5e0c8e1, f41e74e, 3983bcb, fe8dcbe, 3411dcc, b801e70, 3c08f3d, 423c832, a052a2d, 99fc0b2, 47bce1e, d73ab3b
python/pyspark/sql/catalog.py

@@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName):
     @ignore_unicode_prefix
     @since(2.0)
-    def registerFunction(self, name, f, returnType=StringType()):
+    def registerFunction(self, name, f, returnType=None):
         """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statement.
+        as a UDF. The registered UDF can be used in SQL statements.
 
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not given it default to a string and conversion will automatically
-        be done. For any other return type, the produced object must match the specified type.
+        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
 
-        :param name: name of the UDF
-        :param f: a Python function, or a wrapped/native UserDefinedFunction
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return: a wrapped :class:`UserDefinedFunction`
+        In addition to a name and the function itself, `returnType` can be optionally specified.
+        1) When f is a Python function, `returnType` defaults to a string. The produced object must
+        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
+        type of the given UDF as the return type of the registered UDF. The input parameter
+        `returnType` is None by default. If given by users, the value must be None.
+
+        :param name: name of the UDF in SQL statements.
+        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
+            row-at-a-time or vectorized.
+        :param returnType: the return type of the registered UDF.
+        :return: a wrapped/native :class:`UserDefinedFunction`
 
         >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()

@@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
+        >>> from pyspark.sql.types import IntegerType
+        >>> from pyspark.sql.functions import udf
+        >>> slen = udf(lambda s: len(s), IntegerType())
+        >>> _ = spark.udf.register("slen", slen)
+        >>> spark.sql("SELECT slen('test')").collect()
+        [Row(slen(test)=4)]
+
         >>> import random
         >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> from pyspark.sql.types import IntegerType
         >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
         >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=u'82')]
-        >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
-        [Row(random_udf()=u'62')]
+        [Row(random_udf()=82)]
+        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
+        [Row(<lambda>()=26)]
 
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+        ... def add_one(x):
+        ...     return x + 1
+        ...
+        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
Review comment: Is there a reason to return to an underscore placeholder? It might seem confusing to users if not required.
Reply: This is to avoid generating the random hex value returned by PySpark. You can try ... With the underscore placeholder, we can remove ...
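A minimal sketch of the point under discussion, assuming a live SparkSession named `spark` as in the doctests; the variable names and the hex address in the comment are illustrative only:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

slen = udf(lambda s: len(s), IntegerType())

# register() returns the wrapped UDF, and its repr embeds a memory address such as
# <function <lambda> at 0x7f8e3c2b5d08>, which changes on every run, so a doctest
# that echoed the return value could never match its output reliably.
registered = spark.udf.register("slen", slen)

# Binding the result to "_" keeps the doctest line free of that unstable output.
_ = spark.udf.register("slen", slen)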
+        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
         """
 
         # This is to check whether the input function is a wrapped/native UserDefinedFunction
         if hasattr(f, 'asNondeterministic'):
-            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
-                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
-                                      deterministic=f.deterministic)
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
+                    "but got %s." % returnType)

Review comment: Here too, I think we should say ...

+            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+                raise ValueError(
+                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
+            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
+                                               evalType=f.evalType,
+                                               deterministic=f.deterministic)
+            return_udf = f
         else:
-            udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
-        self._jsparkSession.udf().registerPython(name, udf._judf)
-        return udf._wrapped()
+            if returnType is None:
+                returnType = StringType()
+            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
+            return_udf = register_udf._wrapped()
+        self._jsparkSession.udf().registerPython(name, register_udf._judf)
+        return return_udf
 
     @since(2.0)
     def isCached(self, tableName):
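A short usage sketch of the contract implemented above, assuming a live SparkSession named `spark`; the function names are illustrative and the expected results in the comments follow the docstring examples rather than a verified run:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# A plain Python function: returnType is omitted, so results come back as strings.
spark.catalog.registerFunction("plus_one_str", lambda x: x + 1)
spark.sql("SELECT plus_one_str(1)").collect()      # [Row(plus_one_str(1)=u'2')]

# A wrapped UDF already carries its return type, so returnType must stay None.
plus_one = udf(lambda x: x + 1, IntegerType())
spark.catalog.registerFunction("plus_one_int", plus_one)
spark.sql("SELECT plus_one_int(1)").collect()      # [Row(plus_one_int(1)=2)]

# Supplying a returnType together with a UDF object triggers the TypeError above.
try:
    spark.catalog.registerFunction("bad_plus_one", plus_one, IntegerType())
except TypeError as e:
    print(e)  # Invalid returnType: None is expected when f is a UserDefinedFunction, ...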
python/pyspark/sql/context.py

@@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None):
     @ignore_unicode_prefix
     @since(1.2)
-    def registerFunction(self, name, f, returnType=StringType()):
+    def registerFunction(self, name, f, returnType=None):
         """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statement.
+        as a UDF. The registered UDF can be used in SQL statements.
 
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not given it default to a string and conversion will automatically
-        be done. For any other return type, the produced object must match the specified type.
+        :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
 
-        :param name: name of the UDF
-        :param f: a Python function, or a wrapped/native UserDefinedFunction
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return: a wrapped :class:`UserDefinedFunction`
+        In addition to a name and the function itself, `returnType` can be optionally specified.
+        1) When f is a Python function, `returnType` defaults to a string. The produced object must
+        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
+        type of the given UDF as the return type of the registered UDF. The input parameter
+        `returnType` is None by default. If given by users, the value must be None.

Review comment: I think we would simply say that the data type is disallowed to be set to ...

+
+        :param name: name of the UDF in SQL statements.
+        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
+            row-at-a-time or vectorized.
+        :param returnType: the return type of the registered UDF.
+        :return: a wrapped/native :class:`UserDefinedFunction`
 
         >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()

@@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
+        >>> from pyspark.sql.types import IntegerType
+        >>> from pyspark.sql.functions import udf
+        >>> slen = udf(lambda s: len(s), IntegerType())
+        >>> _ = sqlContext.udf.register("slen", slen)
+        >>> sqlContext.sql("SELECT slen('test')").collect()
+        [Row(slen(test)=4)]
+
         >>> import random
         >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> from pyspark.sql.types import IntegerType
         >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+        >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
         >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=u'82')]
-        >>> sqlContext.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
-        [Row(random_udf()=u'62')]
+        [Row(random_udf()=82)]
+        >>> sqlContext.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
+        [Row(<lambda>()=26)]
 
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+        ... def add_one(x):
+        ...     return x + 1
+        ...
+        >>> _ = sqlContext.udf.register("add_one", add_one)  # doctest: +SKIP
+        >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
         """
         return self.sparkSession.catalog.registerFunction(name, f, returnType)

@@ -575,7 +596,7 @@ class UDFRegistration(object):
     def __init__(self, sqlContext):
         self.sqlContext = sqlContext
 
-    def register(self, name, f, returnType=StringType()):
+    def register(self, name, f, returnType=None):
         return self.sqlContext.registerFunction(name, f, returnType)
 
     def registerJavaFunction(self, name, javaClassName, returnType=None):
python/pyspark/sql/tests.py

@@ -379,12 +379,25 @@ def test_udf2(self):
         self.assertEqual(4, res[0])
 
     def test_udf3(self):
-        twoargs = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
-        self.assertEqual(twoargs.deterministic, True)
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
+        self.assertEqual(two_args.deterministic, True)
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], u'5')
 
+    def test_udf_registration_return_type_none(self):
+        two_args = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
+        self.assertEqual(two_args.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_udf_registration_return_type_not_none(self):
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
+                self.spark.catalog.registerFunction(
+                    "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
+
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
         from pyspark.sql.functions import udf

@@ -401,12 +414,12 @@ def test_nondeterministic_udf2(self):
         from pyspark.sql.functions import udf
         random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
         self.assertEqual(random_udf.deterministic, False)
-        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
         self.assertEqual(random_udf1.deterministic, False)
         [row] = self.spark.sql("SELECT randInt()").collect()
-        self.assertEqual(row[0], "6")
+        self.assertEqual(row[0], 6)
         [row] = self.spark.range(1).select(random_udf1()).collect()
-        self.assertEqual(row[0], "6")
+        self.assertEqual(row[0], 6)
         [row] = self.spark.range(1).select(random_udf()).collect()
         self.assertEqual(row[0], 6)
         # render_doc() reproduces the help() exception without printing output

@@ -3581,7 +3594,7 @@ def tearDownClass(cls):
         ReusedSQLTestCase.tearDownClass()
 
     @property
-    def random_udf(self):
+    def nondeterministic_vectorized_udf(self):
         from pyspark.sql.functions import pandas_udf
 
         @pandas_udf('double')

@@ -3616,6 +3629,21 @@ def test_vectorized_udf_basic(self):
                           bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_register_nondeterministic_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.rdd import PythonEvalType
+        import random
+        random_pandas_udf = pandas_udf(
+            lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
+        self.assertEqual(random_pandas_udf.deterministic, False)
+        self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
+            "randomPandasUDF", random_pandas_udf)
+        self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
+        self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
+        self.assertEqual(row[0], 7)
+
     def test_vectorized_udf_null_boolean(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [(True,), (True,), (None,), (False,)]
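The eval-type distinction these tests assert can be inspected directly on the UDF objects; a small sketch, assuming pandas and pyarrow are installed and using the constant names as they appear in this PR:

from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import udf, pandas_udf
from pyspark.sql.types import IntegerType

row_udf = udf(lambda x: x + 1, IntegerType())          # row-at-a-time UDF
vec_udf = pandas_udf(lambda x: x + 1, IntegerType())   # vectorized (scalar pandas) UDF

# registerFunction accepts exactly these two eval types and rejects all others.
assert row_udf.evalType == PythonEvalType.SQL_BATCHED_UDF
assert vec_udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF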
@@ -3975,33 +4003,50 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
         finally:
             self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
-    def test_nondeterministic_udf(self):
+    def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
         from pyspark.sql.functions import udf, pandas_udf, col
 
         @pandas_udf('double')
         def plus_ten(v):
             return v + 10
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf
 
         df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
         result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
 
         self.assertEqual(random_udf.deterministic, False)
         self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
 
-    def test_nondeterministic_udf_in_aggregate(self):
+    def test_nondeterministic_vectorized_udf_in_aggregate(self):
         from pyspark.sql.functions import pandas_udf, sum
 
         df = self.spark.range(10)
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                 df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                 df.agg(sum(random_udf(df.id))).collect()
 
+    def test_register_vectorized_udf_basic(self):
+        from pyspark.rdd import PythonEvalType
+        from pyspark.sql.functions import pandas_udf, col, expr
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'))
+        original_add = pandas_udf(lambda x, y: x + y, IntegerType())
+        self.assertEqual(original_add.deterministic, True)
+        self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        new_add = self.spark.catalog.registerFunction("add1", original_add)
+
+        res1 = df.select(new_add(col('a'), col('b')))
+        res2 = self.spark.sql(
+            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
+        expected = df.select(expr('a + b'))
+        self.assertEquals(expected.collect(), res1.collect())
+        self.assertEquals(expected.collect(), res2.collect())
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedSQLTestCase):

@@ -4037,6 +4082,15 @@ def test_simple(self):
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
         self.assertFramesEqual(expected, result)
 
+    def test_register_group_map_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP)
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
+                                                     'SQL_PANDAS_SCALAR_UDF'):
+                self.spark.catalog.registerFunction("foo_udf", foo_udf)
Review comment: ditto.

 
     def test_decorator(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
         df = self.data
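For contrast with the ValueError checked in test_register_group_map_udf above, a group map pandas UDF is applied through the DataFrame API rather than registered for SQL. A minimal sketch, assuming a running SparkSession named `spark`; the schema and column names are illustrative:

from pyspark.sql.functions import pandas_udf, PandasUDFType

# GROUP_MAP UDFs consume a whole group as a pandas.DataFrame, so they have no
# SQL-callable form here; they are used via groupby().apply() instead.
@pandas_udf("id long, doubled long", PandasUDFType.GROUP_MAP)
def double_id(pdf):
    return pdf.assign(doubled=pdf.id * 2)

spark.range(5).groupby("id").apply(double_id).show()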
Review comment: :func:`spark.catalog.registerFunction` is an alias for :func:`spark.udf.register`?