Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
# limitations under the License.
#

import unittest

from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.functions import udf
from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin

Expand All @@ -36,28 +32,6 @@ def tearDownClass(cls):
finally:
super(ArrowPythonUDFParityTests, cls).tearDownClass()

def test_named_arguments_negative(self):
@udf("int")
def test_udf(a, b):
return a + b

self.spark.udf.register("test_udf", test_udf)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()

with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()


if __name__ == "__main__":
import unittest
Expand Down
24 changes: 23 additions & 1 deletion python/pyspark/sql/tests/test_arrow_python_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import unittest

from pyspark.errors import PythonException, PySparkNotImplementedError
from pyspark.errors import AnalysisException, PythonException, PySparkNotImplementedError
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
Expand Down Expand Up @@ -197,6 +197,28 @@ def test_warn_no_args(self):
" without arguments.",
)

def test_named_arguments_negative(self):
@udf("int")
def test_udf(a, b):
return a + b

self.spark.udf.register("test_udf", test_udf)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()

with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()


class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
Expand Down