53 changes: 29 additions & 24 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -128,20 +128,24 @@ def test_register(self):
         df = self.spark.range(1).selectExpr(
             "array(1, 2, 3) as array",
         )
-        str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True))
-
-        # To verify that Arrow optimization is on
-        self.assertIn(
-            df.selectExpr("str_repr(array) AS str_id").first()[0],
-            ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
-            # The input is a NumPy array when the Arrow optimization is on
-        )
+        with self.temp_func("str_repr"):
+            str_repr_func = self.spark.udf.register(
+                "str_repr", udf(lambda x: str(x), useArrow=True)
+            )

-        # To verify that a UserDefinedFunction is returned
-        self.assertListEqual(
-            df.selectExpr("str_repr(array) AS str_id").collect(),
-            df.select(str_repr_func("array").alias("str_id")).collect(),
-        )
+            # To verify that Arrow optimization is on
+            self.assertIn(
+                df.selectExpr("str_repr(array) AS str_id").first()[0],
+                ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
+                # The input is a NumPy array when the Arrow optimization is on
+            )
+
+            # To verify that a UserDefinedFunction is returned
+            self.assertListEqual(
+                df.selectExpr("str_repr(array) AS str_id").collect(),
+                df.select(str_repr_func("array").alias("str_id")).collect(),
+            )

     def test_nested_array_input(self):
         df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
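The new temp_func context manager is what this change is about: it drops the registered function when the block exits, so the test no longer leaks str_repr into the shared session for later tests. Its definition is not shown in this diff, so the sketch below is only a guess at its shape, modeled on the tempView helper it is combined with further down; the name, placement, and DROP statement are assumptions, not the PR's actual code.

    from contextlib import contextmanager

    @contextmanager
    def temp_func(self, *names):
        """Drop each named function on exit, whether the test passed or failed."""
        try:
            yield
        finally:
            for name in names:
                # DROP TEMPORARY FUNCTION IF EXISTS is standard Spark SQL;
                # the real helper may clean up through a different path.
                self.spark.sql(f"DROP TEMPORARY FUNCTION IF EXISTS {name}")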
@@ -275,22 +279,23 @@ def test_named_arguments_negative(self):
         def test_udf(a, b):
             return a + b

-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)

-        with self.assertRaisesRegex(
-            AnalysisException,
-            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

-        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()

-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()

     def test_udf_with_udt(self):
         for fallback in [False, True]:
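All four blocks in that hunk exercise error paths of the named-argument syntax (a => expr). For contrast, a well-formed call may pass the names in any order; a hypothetical positive case, not part of this diff:

    self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()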
31 changes: 19 additions & 12 deletions python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -488,12 +488,19 @@ def test_register_vectorized_udf_basic(self):
         )

         self.assertEqual(sum_arrow_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
-        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        q = "SELECT sum_arrow_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
-        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
-        expected = [1, 5]
-        self.assertEqual(actual, expected)
+
+        with self.temp_func("sum_arrow_udf"):
+            group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
+            self.assertEqual(
+                group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
+            )
+            q = """
+                SELECT sum_arrow_udf(v1)
+                FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2
+            """
+            actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+            expected = [1, 5]
+            self.assertEqual(actual, expected)

     def test_grouped_with_empty_partition(self):
         import pyarrow as pa
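This hunk only re-registers sum_arrow_udf; its definition sits above the hunk and is not shown. Judging from the evalType assertion and the max_udf body in the next hunk, a grouped-aggregate Arrow UDF in this family plausibly reduces one Arrow array per group to a scalar, along these lines — the arrow_udf import path and decorator signature here are assumptions, not code from this file:

    import pyarrow as pa
    from pyspark.sql.functions import arrow_udf  # assumed import path

    @arrow_udf("int")
    def sum_arrow_udf(v: pa.Array) -> int:
        # One Arrow array holding every value of the group comes in;
        # a single scalar goes out, hence SQL_GROUPED_AGG_ARROW_UDF.
        return pa.compute.sum(v).as_py()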
@@ -516,10 +523,10 @@ def max_udf(v):
             return float(pa.compute.max(v).as_py())

         df = self.spark.range(0, 100)
-        self.spark.udf.register("max_udf", max_udf)

-        with self.tempView("table"):
+        with self.tempView("table"), self.temp_func("max_udf"):
             df.createTempView("table")
+            self.spark.udf.register("max_udf", max_udf)

             agg1 = df.agg(max_udf(df["id"]))
             agg2 = self.spark.sql("select max_udf(id) from table")
@@ -546,7 +553,7 @@ def test_named_arguments(self):
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)

@@ -575,7 +582,7 @@ def test_named_arguments_negative(self):
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)

@@ -615,7 +622,7 @@ def weighted_mean(**kwargs):

             return np.average(kwargs["v"], weights=kwargs["w"])

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)

@@ -660,7 +667,7 @@ def test_named_arguments_and_defaults(self):
         def biased_sum(v, w=None):
             return pa.compute.sum(v).as_py() + (pa.compute.sum(w).as_py() if w is not None else 100)

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("biased_sum"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("biased_sum", biased_sum)

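Because biased_sum declares w with a default, both call shapes are legal once it is registered; hypothetical queries against the v view (column names borrowed from the surrounding tests, not shown in this hunk):

    self.spark.sql("SELECT biased_sum(v) FROM v")     # w omitted: adds the fallback 100
    self.spark.sql("SELECT biased_sum(v, w) FROM v")  # w supplied: adds sum(w) instead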