Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def apply_type_coercion():
if batch.num_columns == 0:
coerced_batch = batch # skip type coercion
else:
expected_field_names = arrow_return_type.names
expected_field_names = [field.name for field in arrow_return_type]
actual_field_names = batch.schema.names

if expected_field_names != actual_field_names:
Expand All @@ -283,7 +283,7 @@ def apply_type_coercion():
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
coerced_batch = pa.RecordBatch.from_arrays(
coerced_arrays, names=arrow_return_type.names
coerced_arrays, names=expected_field_names
)
yield coerced_batch, arrow_return_type

Expand Down
31 changes: 31 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,37 @@ def eval(self, input_val: int):
expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_with_named_arguments(self):
    """Verify an Arrow UDTF is callable from SQL with named and mixed arguments."""

    @arrow_udtf(returnType="x int, y int, sum int")
    class NamedArgsUDTF:
        def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
            # Inputs must arrive as Arrow arrays, regardless of how the SQL
            # call site spelled the arguments (positional or named).
            assert isinstance(x, pa.Array), f"Expected pa.Array, got {type(x)}"
            assert isinstance(y, pa.Array), f"Expected pa.Array, got {type(y)}"

            first_x = x[0].as_py()
            first_y = y[0].as_py()
            yield pa.table(
                {
                    "x": pa.array([first_x], type=pa.int32()),
                    "y": pa.array([first_y], type=pa.int32()),
                    "sum": pa.array([first_x + first_y], type=pa.int32()),
                }
            )

    # Register the UDTF so it can be invoked from SQL by name.
    self.spark.udtf.register("named_args_udtf", NamedArgsUDTF)

    # All-named arguments, deliberately given out of declaration order:
    # the engine must map y=>10, x=>5 back onto eval(x, y).
    named_df = self.spark.sql("SELECT * FROM named_args_udtf(y => 10, x => 5)")
    assertDataFrameEqual(
        named_df,
        self.spark.createDataFrame([(5, 10, 15)], "x int, y int, sum int"),
    )

    # A positional first argument combined with a named second argument.
    mixed_df = self.spark.sql("SELECT * FROM named_args_udtf(7, y => 3)")
    assertDataFrameEqual(
        mixed_df,
        self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int"),
    )


class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
    """Concrete runner for the Arrow UDTF test mixin.

    All test cases live in ``ArrowUDTFTestsMixin``; ``ReusedSQLTestCase``
    supplies the shared SparkSession fixture, so no extra body is needed.
    """

    pass
Expand Down