[SPARK-45266][PYTHON][FOLLOWUP] Fix to resolve UnresolvedPolymorphicPythonUDTF when the table argument is specified as a named argument #43355
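This follow-up to SPARK-45266 (polymorphic Python UDTFs) makes the analyzer resolve an UnresolvedPolymorphicPythonUDTF call even when its table argument is passed by name rather than by position. A minimal reproduction sketch, assuming an active SparkSession bound to `spark`; the CountRows UDTF and view `t` are illustrative stand-ins, not code from this PR:

    from pyspark.sql.functions import udtf
    from pyspark.sql.types import IntegerType, StructType
    from pyspark.sql.udtf import AnalyzeResult

    # Illustrative polymorphic UDTF: no returnType is given, so Spark plans an
    # UnresolvedPolymorphicPythonUDTF and must call analyze() to get the schema.
    @udtf
    class CountRows:
        @staticmethod
        def analyze(*args, **kwargs):
            return AnalyzeResult(schema=StructType().add("count", IntegerType()))

        def __init__(self):
            self._count = 0

        def eval(self, row):
            self._count += 1

        def terminate(self):
            yield (self._count,)

    spark.udtf.register("count_rows", CountRows)
    spark.range(3).createOrReplaceTempView("t")

    # Positional table argument: already resolved before this follow-up.
    spark.sql("SELECT * FROM count_rows(TABLE(t))").show()
    # Named table argument: the form this follow-up fixes.
    spark.sql("SELECT * FROM count_rows(row => TABLE(t))").show()

Per the PR title, the positional form resolved but the named form did not, because the named argument wraps the table-argument expression (see the analyzer hunk at the end of this diff).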

Closed
272 changes: 154 additions & 118 deletions python/pyspark/sql/tests/test_udtf.py
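All of the test changes below apply one refactor: each inline SQL assertion becomes a shared base_query template with a {table_arg} placeholder, formatted once with the positional form TABLE(t) and once with the named form row => TABLE(t), each variant running under self.subTest so one failing form does not mask the other. A minimal sketch of the idiom with generic names, not the actual Spark test:

    import unittest

    class TableArgVariantsTest(unittest.TestCase):
        def test_both_table_arg_forms(self):
            base_query = "SELECT * FROM test_udtf({table_arg})"
            for table_arg in ["TABLE(t)", "row => TABLE(t)"]:
                # subTest isolates each variant: if one fails, the other still
                # runs, and the failure report names the offending table_arg.
                with self.subTest(table_arg=table_arg):
                    query = base_query.format(table_arg=table_arg)
                    self.assertIn("test_udtf(", query)

    if __name__ == "__main__":
        unittest.main()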
@@ -2086,65 +2086,89 @@ def terminate(self):
         # This is a basic example.
         func = udtf(TestUDTF, returnType="partition_col: int, total: int")
         self.spark.udtf.register("test_udtf", func)
-        assertDataFrameEqual(
-            self.spark.sql(
-                """
-                WITH t AS (
-                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
-                  UNION ALL
-                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
-                )
-                SELECT partition_col, total
-                FROM test_udtf(TABLE(t) PARTITION BY partition_col - 1)
-                ORDER BY 1, 2
-                """
-            ).collect(),
-            [Row(partition_col=x, total=3) for x in range(1, 21)],
-        )
+
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+            )
+            SELECT partition_col, total
+            FROM test_udtf({table_arg})
+            ORDER BY 1, 2
+        """
+
+        for table_arg in [
+            "TABLE(t) PARTITION BY partition_col - 1",
+            "row => TABLE(t) PARTITION BY partition_col - 1",
+        ]:
+            with self.subTest(table_arg=table_arg):
+                assertDataFrameEqual(
+                    self.spark.sql(base_query.format(table_arg=table_arg)),
+                    [Row(partition_col=x, total=3) for x in range(1, 21)],
+                )
 
+        base_query = """
+            WITH t AS (
+              SELECT {str_first} AS partition_col, id AS input FROM range(0, 2)
+              UNION ALL
+              SELECT {str_second} AS partition_col, id AS input FROM range(0, 2)
+            )
+            SELECT partition_col, total
+            FROM test_udtf({table_arg})
+            ORDER BY 1, 2
+        """
+
         # These cases partition by constant values.
         for str_first, str_second, result_first, result_second in (
             ("123", "456", 123, 456),
             ("123", "NULL", None, 123),
         ):
-            assertDataFrameEqual(
-                self.spark.sql(
-                    f"""
-                    WITH t AS (
-                      SELECT {str_first} AS partition_col, id AS input FROM range(0, 2)
-                      UNION ALL
-                      SELECT {str_second} AS partition_col, id AS input FROM range(0, 2)
-                    )
-                    SELECT partition_col, total
-                    FROM test_udtf(TABLE(t) PARTITION BY partition_col)
-                    ORDER BY 1, 2
-                    """
-                ).collect(),
-                [
-                    Row(partition_col=result_first, total=1),
-                    Row(partition_col=result_second, total=1),
-                ],
-            )
+            for table_arg in [
+                "TABLE(t) PARTITION BY partition_col",
+                "row => TABLE(t) PARTITION BY partition_col",
+            ]:
+                with self.subTest(str_first=str_first, str_second=str_second, table_arg=table_arg):
+                    assertDataFrameEqual(
+                        self.spark.sql(
+                            base_query.format(
+                                str_first=str_first, str_second=str_second, table_arg=table_arg
+                            )
+                        ),
+                        [
+                            Row(partition_col=result_first, total=1),
+                            Row(partition_col=result_second, total=1),
+                        ],
+                    )
 
         # Combine a lateral join with a TABLE argument with PARTITION BY.
         func = udtf(TestUDTF, returnType="partition_col: int, total: int")
         self.spark.udtf.register("test_udtf", func)
-        assertDataFrameEqual(
-            self.spark.sql(
-                """
-                WITH t AS (
-                  SELECT id AS partition_col, 1 AS input FROM range(1, 3)
-                  UNION ALL
-                  SELECT id AS partition_col, 2 AS input FROM range(1, 3)
-                )
-                SELECT v.a, v.b, f.partition_col, f.total
-                FROM VALUES (0, 1) AS v(a, b),
-                  LATERAL test_udtf(TABLE(t) PARTITION BY partition_col - 1) f
-                ORDER BY 1, 2, 3, 4
-                """
-            ).collect(),
-            [Row(a=0, b=1, partition_col=1, total=3), Row(a=0, b=1, partition_col=2, total=3)],
-        )
+
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 3)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 3)
+            )
+            SELECT v.a, v.b, f.partition_col, f.total
+            FROM VALUES (0, 1) AS v(a, b),
+              LATERAL test_udtf({table_arg}) f
+            ORDER BY 1, 2, 3, 4
+        """
+
+        for table_arg in [
+            "TABLE(t) PARTITION BY partition_col - 1",
+            "row => TABLE(t) PARTITION BY partition_col - 1",
+        ]:
+            with self.subTest(func_call=table_arg):
+                assertDataFrameEqual(
+                    self.spark.sql(base_query.format(table_arg=table_arg)),
+                    [
+                        Row(a=0, b=1, partition_col=1, total=3),
+                        Row(a=0, b=1, partition_col=2, total=3),
+                    ],
+                )
 
     def test_udtf_with_table_argument_and_partition_by_and_order_by(self):
         class TestUDTF:
@@ -2172,29 +2196,35 @@ def terminate(self):
 
         func = udtf(TestUDTF, returnType="partition_col: int, last: int")
         self.spark.udtf.register("test_udtf", func)
+
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+            )
+            SELECT partition_col, last
+            FROM test_udtf(
+              {table_arg},
+              partition_col => 'partition_col')
+            ORDER BY 1, 2
+        """
+
         for order_by_str, result_val in (
             ("input ASC", 2),
             ("input + 1 ASC", 2),
             ("input DESC", 1),
             ("input - 1 DESC", 1),
         ):
-            assertDataFrameEqual(
-                self.spark.sql(
-                    f"""
-                    WITH t AS (
-                      SELECT id AS partition_col, 1 AS input FROM range(1, 21)
-                      UNION ALL
-                      SELECT id AS partition_col, 2 AS input FROM range(1, 21)
-                    )
-                    SELECT partition_col, last
-                    FROM test_udtf(
-                      row => TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str},
-                      partition_col => 'partition_col')
-                    ORDER BY 1, 2
-                    """
-                ).collect(),
-                [Row(partition_col=x, last=result_val) for x in range(1, 21)],
-            )
+            for table_arg in [
+                f"TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str}",
+                f"row => TABLE(t) PARTITION BY partition_col - 1 ORDER BY {order_by_str}",
+            ]:
+                with self.subTest(table_arg=table_arg):
+                    assertDataFrameEqual(
+                        self.spark.sql(base_query.format(table_arg=table_arg)),
+                        [Row(partition_col=x, last=result_val) for x in range(1, 21)],
+                    )
 
     def test_udtf_with_table_argument_with_single_partition(self):
         class TestUDTF:
@@ -2218,23 +2248,27 @@ def terminate(self):
 
         func = udtf(TestUDTF, returnType="count: int, total: int, last: int")
         self.spark.udtf.register("test_udtf", func)
-        assertDataFrameEqual(
-            self.spark.sql(
-                """
-                WITH t AS (
-                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
-                  UNION ALL
-                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
-                )
-                SELECT count, total, last
-                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col))
-                ORDER BY 1, 2
-                """
-            ).collect(),
-            [
-                Row(count=40, total=60, last=2),
-            ],
-        )
+
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+            )
+            SELECT count, total, last
+            FROM test_udtf({table_arg})
+            ORDER BY 1, 2
+        """
+
+        for table_arg in [
+            "TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col)",
+            "row => TABLE(t) WITH SINGLE PARTITION ORDER BY (input, partition_col)",
+        ]:
+            with self.subTest(table_arg=table_arg):
+                assertDataFrameEqual(
+                    self.spark.sql(base_query.format(table_arg=table_arg)),
+                    [Row(count=40, total=60, last=2)],
+                )
 
     def test_udtf_with_table_argument_with_single_partition_from_analyze(self):
         @udtf
@@ -2245,7 +2279,7 @@ def __init__(self):
                 self._last = None
 
             @staticmethod
-            def analyze(self):
+            def analyze(*args, **kwargs):
                 return AnalyzeResult(
                     schema=StructType()
                     .add("count", IntegerType())
@@ -2270,23 +2304,23 @@ def terminate(self):
 
         self.spark.udtf.register("test_udtf", TestUDTF)
 
-        assertDataFrameEqual(
-            self.spark.sql(
-                """
-                WITH t AS (
-                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
-                  UNION ALL
-                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
-                )
-                SELECT count, total, last
-                FROM test_udtf(TABLE(t))
-                ORDER BY 1, 2
-                """
-            ).collect(),
-            [
-                Row(count=40, total=60, last=2),
-            ],
-        )
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+            )
+            SELECT count, total, last
+            FROM test_udtf({table_arg})
+            ORDER BY 1, 2
+        """
+
+        for table_arg in ["TABLE(t)", "row => TABLE(t)"]:
+            with self.subTest(table_arg):
+                assertDataFrameEqual(
+                    self.spark.sql(base_query.format(table_arg=table_arg)),
+                    [Row(count=40, total=60, last=2)],
+                )
 
     def test_udtf_with_table_argument_with_partition_by_and_order_by_from_analyze(self):
         @udtf
@@ -2298,7 +2332,7 @@ def __init__(self):
                 self._last = None
 
             @staticmethod
-            def analyze(self):
+            def analyze(*args, **kwargs):
                 return AnalyzeResult(
                     schema=StructType()
                     .add("partition_col", IntegerType())
@@ -2343,28 +2377,30 @@ def terminate(self):
 
         self.spark.udtf.register("test_udtf", TestUDTF)
 
-        assertDataFrameEqual(
-            self.spark.sql(
-                """
-                WITH t AS (
-                  SELECT id AS partition_col, 1 AS input FROM range(1, 21)
-                  UNION ALL
-                  SELECT id AS partition_col, 2 AS input FROM range(1, 21)
-                  UNION ALL
-                  SELECT 42 AS partition_col, NULL AS input
-                  UNION ALL
-                  SELECT 42 AS partition_col, 1 AS input
-                  UNION ALL
-                  SELECT 42 AS partition_col, 2 AS input
-                )
-                SELECT partition_col, count, total, last
-                FROM test_udtf(TABLE(t))
-                ORDER BY 1, 2
-                """
-            ).collect(),
-            [Row(partition_col=x, count=2, total=3, last=2) for x in range(1, 21)]
-            + [Row(partition_col=42, count=3, total=3, last=None)],
-        )
+        base_query = """
+            WITH t AS (
+              SELECT id AS partition_col, 1 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT id AS partition_col, 2 AS input FROM range(1, 21)
+              UNION ALL
+              SELECT 42 AS partition_col, NULL AS input
+              UNION ALL
+              SELECT 42 AS partition_col, 1 AS input
+              UNION ALL
+              SELECT 42 AS partition_col, 2 AS input
+            )
+            SELECT partition_col, count, total, last
+            FROM test_udtf({table_arg})
+            ORDER BY 1, 2
+        """
+
+        for table_arg in ["TABLE(t)", "row => TABLE(t)"]:
+            with self.subTest(table_arg=table_arg):
+                assertDataFrameEqual(
+                    self.spark.sql(base_query.format(table_arg=table_arg)),
+                    [Row(partition_col=x, count=2, total=3, last=2) for x in range(1, 21)]
+                    + [Row(partition_col=42, count=3, total=3, last=None)],
+                )
 
     def test_udtf_with_prepare_string_from_analyze(self):
         @dataclass
@@ -2414,7 +2450,7 @@ def terminate(self):
                 SELECT total, buffer
                 FROM test_udtf("abc", TABLE(t))
                 """
-            ).collect(),
+            ),
             [Row(count=20, buffer="abc")],
         )
 
@@ -2225,6 +2225,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // to apply the requested partitioning and/or ordering.
       val analyzeResult = u.resolveElementMetadata(u.func, u.children)
       val newChildren = u.children.map {
+        case NamedArgumentExpression(key, t: FunctionTableSubqueryArgumentExpression) =>
+          NamedArgumentExpression(key, analyzeResult.applyToTableArgument(u.name, t))
         case t: FunctionTableSubqueryArgumentExpression =>
           analyzeResult.applyToTableArgument(u.name, t)
         case c => c
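The analyzer hunk above is the substance of the fix. Previously, a table argument wrapped in a NamedArgumentExpression (the row => TABLE(t) form) matched only the catch-all case c => c, so applyToTableArgument never ran and the requested partitioning and ordering were never applied. The new case unwraps the named argument, rewrites the table argument, and re-wraps it under the same key. A Python sketch of that unwrap/re-wrap pattern, with hypothetical stand-in types rather than the Catalyst API:

    from dataclasses import dataclass, replace
    from typing import Any

    @dataclass
    class TableArg:   # stands in for FunctionTableSubqueryArgumentExpression
        plan: str

    @dataclass
    class NamedArg:   # stands in for NamedArgumentExpression
        key: str
        value: Any

    def apply_to_table_argument(t: TableArg) -> TableArg:
        # Stand-in for analyzeResult.applyToTableArgument(u.name, t).
        return replace(t, plan=t.plan + " [with requested partitioning/ordering]")

    def resolve_child(child: Any) -> Any:
        # New case: unwrap the named argument, rewrite, re-wrap under the key.
        if isinstance(child, NamedArg) and isinstance(child.value, TableArg):
            return NamedArg(child.key, apply_to_table_argument(child.value))
        if isinstance(child, TableArg):
            return apply_to_table_argument(child)
        return child  # all other children pass through unchanged

    # row => TABLE(t) arrives wrapped; before the fix it hit the pass-through.
    print(resolve_child(NamedArg("row", TableArg("TABLE(t)"))))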