14 changes: 14 additions & 0 deletions python/pyspark/sql/conversion.py
@@ -90,6 +90,20 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record
struct = batch.column(column_index)
return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))

@classmethod
def select_columns(cls, batch: "pa.RecordBatch", column_indices: list[int]) -> "pa.RecordBatch":
"""
Select a subset of columns from a RecordBatch by index.

Used by: SQL_COGROUPED_MAP_ARROW_UDF handler in worker.py
"""
import pyarrow as pa

return pa.RecordBatch.from_arrays(
[batch.columns[i] for i in column_indices],
[batch.schema.names[i] for i in column_indices],
)

@staticmethod
def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
"""
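For context, a minimal usage sketch of the new helper with toy data; it assumes ArrowBatchTransformer (the class referenced by the worker.py handler below) is importable from pyspark.sql.conversion:

import pyarrow as pa

from pyspark.sql.conversion import ArrowBatchTransformer

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array(["a", "b"]), pa.array([0.1, 0.2])],
    names=["id", "name", "score"],
)
# Keep only columns 0 and 2; names are carried over from the schema.
subset = ArrowBatchTransformer.select_columns(batch, [0, 2])
assert subset.schema.names == ["id", "score"]
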
148 changes: 77 additions & 71 deletions python/pyspark/worker.py
@@ -63,7 +63,7 @@
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
GroupPandasUDFSerializer,
CogroupArrowUDFSerializer,
ArrowStreamCoGroupSerializer,
Review comment (Contributor):

After this PR lands, CogroupArrowUDFSerializer (python/pyspark/sql/pandas/serializers.py:704) loses its only user. Its parent ArrowStreamGroupUDFSerializer (line 301) was already orphaned by SPARK-55608 — its only remaining subclass was CogroupArrowUDFSerializer. After this PR, both classes are unreachable: no imports in worker.py, no other subclasses, no public re-exports in __init__.py. Suggest removing both in this PR (or as a small follow-up) so the dead code doesn't accumulate across the SPARK-55388 series.

CogroupPandasUDFSerializer,
ApplyInPandasWithStateSerializer,
TransformWithStateInPandasSerializer,
@@ -511,34 +511,6 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu
)


def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
if runner_conf.assign_cols_by_name:
expected_cols_and_types = {
col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
}
else:
expected_cols_and_types = [
(col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields
]

def wrapped(left_key_table, left_value_table, right_key_table, right_value_table):
if len(argspec.args) == 2:
result = f(left_value_table, right_value_table)
elif len(argspec.args) == 3:
key_table = left_key_table if left_key_table.num_rows > 0 else right_key_table
key = tuple(c[0] for c in key_table.columns)
result = f(key, left_value_table, right_value_table)

verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)

return result.to_batches()

return lambda kl, vl, kr, vr: (
wrapped(kl, vl, kr, vr),
to_arrow_type(return_type, timezone="UTC"),
)


def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
import pandas as pd
@@ -622,32 +594,31 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
)


def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
import pyarrow as pa

if not isinstance(table, pa.Table):
def verify_result_type(result, expected_type):
"""Check that a UDF result is the expected type, raising UDF_RETURN_TYPE if not."""
if not isinstance(result, expected_type):
package = getattr(inspect.getmodule(expected_type), "__package__", "")
label = f"{package}.{expected_type.__name__}" if package else expected_type.__name__
raise PySparkTypeError(
errorClass="UDF_RETURN_TYPE",
messageParameters={
"expected": "pyarrow.Table",
"actual": type(table).__name__,
"expected": label,
"actual": type(result).__name__,
},
)


def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
import pyarrow as pa

verify_result_type(table, pa.Table)
verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)


def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
import pyarrow as pa

if not isinstance(batch, pa.RecordBatch):
raise PySparkTypeError(
errorClass="UDF_RETURN_TYPE",
messageParameters={
"expected": "pyarrow.RecordBatch",
"actual": type(batch).__name__,
},
)
verify_result_type(batch, pa.RecordBatch)

verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
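
As a sanity check, the shared helper's failure path can be exercised standalone. A sketch with toy input, assuming verify_result_type is imported from pyspark.worker and PySparkTypeError from pyspark.errors:

import pyarrow as pa

from pyspark.errors import PySparkTypeError
from pyspark.worker import verify_result_type

try:
    # A dict is not a pa.Table, so this raises UDF_RETURN_TYPE with the
    # package-qualified expected type ("pyarrow.Table") and actual "dict".
    verify_result_type({}, pa.Table)
except PySparkTypeError as exc:
    print(exc)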

@@ -1173,7 +1144,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, argspec, runner_conf)
return func, args_offsets, return_type, len(argspec.args)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return wrap_grouped_agg_pandas_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf
@@ -2425,7 +2396,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = CogroupArrowUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name)
ser = ArrowStreamCoGroupSerializer(write_start_stream=True)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = CogroupPandasUDFSerializer(
timezone=runner_conf.timezone,
@@ -2943,6 +2914,67 @@ def grouped_func(
# profiling is not supported for UDF
return grouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
import pyarrow as pa

assert num_udfs == 1, "One COGROUPED_MAP_ARROW UDF expected here."
cogrouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]

parsed_offsets = extract_key_value_indexes(arg_offsets)

# Pre-compute expected column names/types for strict result validation.
# Cogrouped map has a strict contract: missing, extra, or type-mismatched
# columns must raise; no silent coercion.
if runner_conf.assign_cols_by_name:
expected_cols_and_types = {
col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
}
reorder_names = [col.name for col in return_type.fields]
else:
expected_cols_and_types = [
(col.name, to_arrow_type(col.dataType, timezone="UTC"))
for col in return_type.fields
]
reorder_names = None

select_columns = ArrowBatchTransformer.select_columns
left_key_cols, left_val_cols = parsed_offsets[0]
right_key_cols, right_val_cols = parsed_offsets[1]

def table_from_batches(batches, cols):
return pa.Table.from_batches([select_columns(b, cols) for b in batches])

def cogrouped_func(
split_index: int,
data: Iterator[Tuple[list[pa.RecordBatch], list[pa.RecordBatch]]],
) -> Iterator[pa.RecordBatch]:
Review comment (Contributor) on lines +2947 to +2950:

Suggest adding a docstring to match the peer pattern. The Arrow analogue at worker.py:2802 has """Apply groupBy Arrow UDF (non-iterator variant).""", and the new grouped_func in #55495 has a longer one. Without it cogrouped_func reads a bit terse compared to its peers.

Suggested change (adds a docstring after the signature):

def cogrouped_func(
    split_index: int,
    data: Iterator[Tuple[list[pa.RecordBatch], list[pa.RecordBatch]]],
) -> Iterator[pa.RecordBatch]:
    """Apply cogroupBy Arrow UDF."""

for left_batches, right_batches in data:
left_keys = table_from_batches(left_batches, left_key_cols)
left_values = table_from_batches(left_batches, left_val_cols)
right_keys = table_from_batches(right_batches, right_key_cols)
right_values = table_from_batches(right_batches, right_val_cols)

if num_udf_args == 2:
result = cogrouped_udf(left_values, right_values)
else:
key_table = left_keys if left_keys.num_rows > 0 else right_keys
key = tuple(c[0] for c in key_table.columns)
result = cogrouped_udf(key, left_values, right_values)

verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)

for batch in result.to_batches():
if reorder_names is not None:
# Names and types already validated equal; pure reorder, no cast.
batch = pa.RecordBatch.from_arrays(
[batch.column(name) for name in reorder_names],
names=reorder_names,
)
yield ArrowBatchTransformer.wrap_struct(batch)

# profiling is not supported for UDF
return cogrouped_func, None, ser, ser
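
For reference, the name-based reorder inside the loop above amounts to the following standalone pyarrow sketch (toy data, not code from the PR); names and types were already validated equal, so it is a pure column permutation with no cast:

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array(["a", "b"]), pa.array([1, 2])],
    names=["name", "id"],
)
reorder_names = ["id", "name"]  # column order declared by the return type
# RecordBatch.column accepts a column name, so this only permutes columns.
reordered = pa.RecordBatch.from_arrays(
    [batch.column(n) for n in reorder_names],
    names=reorder_names,
)
assert reordered.schema.names == ["id", "name"]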

if (
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
@@ -3431,32 +3463,6 @@ def mapper(a):
df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
return f(df1_keys, df1_vals, df2_keys, df2_vals)

elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
import pyarrow as pa

# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
arg_offsets, f = udfs[0]

parsed_offsets = extract_key_value_indexes(arg_offsets)

def batch_from_offset(batch, offsets):
return pa.RecordBatch.from_arrays(
arrays=[batch.columns[o] for o in offsets],
names=[batch.schema.names[o] for o in offsets],
)

def table_from_batches(batches, offsets):
return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches])

def mapper(a):
df1_keys = table_from_batches(a[0], parsed_offsets[0][0])
df1_vals = table_from_batches(a[0], parsed_offsets[0][1])
df2_keys = table_from_batches(a[1], parsed_offsets[1][0])
df2_vals = table_from_batches(a[1], parsed_offsets[1][1])
return f(df1_keys, df1_vals, df2_keys, df2_vals)

elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
# We assume there is only one UDF here because grouped agg doesn't
# support combining multiple UDFs.