diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py index fa329b598d98b..732008eb05a35 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py @@ -15,10 +15,6 @@ # limitations under the License. # -import unittest - -from pyspark.errors import AnalysisException, PythonException -from pyspark.sql.functions import udf from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin @@ -36,28 +32,6 @@ def tearDownClass(cls): finally: super(ArrowPythonUDFParityTests, cls).tearDownClass() - def test_named_arguments_negative(self): - @udf("int") - def test_udf(a, b): - return a + b - - self.spark.udf.register("test_udf", test_udf) - - with self.assertRaisesRegex( - AnalysisException, - "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", - ): - self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() - - with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): - self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() - - with self.assertRaises(PythonException): - self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() - - with self.assertRaises(PythonException): - self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show() - if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 23f302ec3c8d3..5a66d61cb66a2 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -17,7 +17,7 @@ import unittest -from pyspark.errors import PythonException, PySparkNotImplementedError +from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError from pyspark.sql import Row from pyspark.sql.functions import udf from pyspark.sql.tests.test_udf import BaseUDFTestsMixin @@ -197,6 +197,28 @@ def test_warn_no_args(self): " without arguments.", ) + def test_named_arguments_negative(self): + @udf("int") + def test_udf(a, b): + return a + b + + self.spark.udf.register("test_udf", test_udf) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() + + with self.assertRaises(PythonException): + self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() + + with self.assertRaises(PythonException): + self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show() + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod