diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 8fc4fa5cc0cc7..a229386f30010 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -90,6 +90,20 @@ def flatten_struct(batch: "pa.RecordBatch", column_index: int = 0) -> "pa.Record
         struct = batch.column(column_index)
         return pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
 
+    @classmethod
+    def select_columns(cls, batch: "pa.RecordBatch", column_indices: list[int]) -> "pa.RecordBatch":
+        """
+        Select a subset of columns from a RecordBatch by index.
+
+        Used by: SQL_COGROUPED_MAP_ARROW_UDF handler in worker.py
+        """
+        import pyarrow as pa
+
+        return pa.RecordBatch.from_arrays(
+            [batch.columns[i] for i in column_indices],
+            [batch.schema.names[i] for i in column_indices],
+        )
+
     @staticmethod
     def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
         """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4bb81ae044ea6..8802e1e9d7291 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -63,7 +63,7 @@
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupPandasUDFSerializer,
-    CogroupArrowUDFSerializer,
+    ArrowStreamCoGroupSerializer,
     CogroupPandasUDFSerializer,
     ApplyInPandasWithStateSerializer,
     TransformWithStateInPandasSerializer,
@@ -511,34 +511,6 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu
     )
 
 
-def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
-    if runner_conf.assign_cols_by_name:
-        expected_cols_and_types = {
-            col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
-        }
-    else:
-        expected_cols_and_types = [
-            (col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields
-        ]
-
-    def wrapped(left_key_table, left_value_table, right_key_table, right_value_table):
-        if len(argspec.args) == 2:
-            result = f(left_value_table, right_value_table)
-        elif len(argspec.args) == 3:
-            key_table = left_key_table if left_key_table.num_rows > 0 else right_key_table
-            key = tuple(c[0] for c in key_table.columns)
-            result = f(key, left_value_table, right_value_table)
-
-        verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
-
-        return result.to_batches()
-
-    return lambda kl, vl, kr, vr: (
-        wrapped(kl, vl, kr, vr),
-        to_arrow_type(return_type, timezone="UTC"),
-    )
-
-
 def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
         import pandas as pd
@@ -622,32 +594,31 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
     )
 
 
-def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(table, pa.Table):
+def verify_result_type(result, expected_type):
+    """Check that a UDF result is the expected type, raising UDF_RETURN_TYPE if not."""
+    if not isinstance(result, expected_type):
+        package = getattr(inspect.getmodule(expected_type), "__package__", "")
+        label = f"{package}.{expected_type.__name__}" if package else expected_type.__name__
         raise PySparkTypeError(
             errorClass="UDF_RETURN_TYPE",
             messageParameters={
-                "expected": "pyarrow.Table",
-                "actual": type(table).__name__,
+                "expected": label,
+                "actual": type(result).__name__,
             },
         )
+
+
+def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
+    import pyarrow as pa
+
+    verify_result_type(table, pa.Table)
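+    # The shared verify_arrow_result below then checks column names and types.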
     verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)
 
 
 def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
     import pyarrow as pa
 
-    if not isinstance(batch, pa.RecordBatch):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.RecordBatch",
-                "actual": type(batch).__name__,
-            },
-        )
+    verify_result_type(batch, pa.RecordBatch)
     verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
 
 
@@ -1173,7 +1144,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
-        return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, argspec, runner_conf)
+        return func, args_offsets, return_type, len(argspec.args)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return wrap_grouped_agg_pandas_udf(
             func, args_offsets, kwargs_offsets, return_type, runner_conf
@@ -2425,7 +2396,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
             int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
         )
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
-        ser = CogroupArrowUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name)
+        ser = ArrowStreamCoGroupSerializer(write_start_stream=True)
    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        ser = CogroupPandasUDFSerializer(
            timezone=runner_conf.timezone,
@@ -2943,6 +2914,67 @@ def grouped_func(
         # profiling is not supported for UDF
         return grouped_func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
+        import pyarrow as pa
+
+        assert num_udfs == 1, "One COGROUPED_MAP_ARROW UDF expected here."
+        cogrouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
+
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        # Pre-compute expected column names/types for strict result validation.
+        # Cogrouped map has a strict contract: missing, extra, or type-mismatched
+        # columns must raise; no silent coercion.
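+        # By-name matching keys the expected types on column name; positional
+        # matching checks (name, type) pairs in the declared field order.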
+        if runner_conf.assign_cols_by_name:
+            expected_cols_and_types = {
+                col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
+            }
+            reorder_names = [col.name for col in return_type.fields]
+        else:
+            expected_cols_and_types = [
+                (col.name, to_arrow_type(col.dataType, timezone="UTC"))
+                for col in return_type.fields
+            ]
+            reorder_names = None
+
+        select_columns = ArrowBatchTransformer.select_columns
+        left_key_cols, left_val_cols = parsed_offsets[0]
+        right_key_cols, right_val_cols = parsed_offsets[1]
+
+        def table_from_batches(batches, cols):
+            return pa.Table.from_batches([select_columns(b, cols) for b in batches])
+
+        def cogrouped_func(
+            split_index: int,
+            data: Iterator[Tuple[list[pa.RecordBatch], list[pa.RecordBatch]]],
+        ) -> Iterator[pa.RecordBatch]:
+            for left_batches, right_batches in data:
+                left_keys = table_from_batches(left_batches, left_key_cols)
+                left_values = table_from_batches(left_batches, left_val_cols)
+                right_keys = table_from_batches(right_batches, right_key_cols)
+                right_values = table_from_batches(right_batches, right_val_cols)
+
+                if num_udf_args == 2:
+                    result = cogrouped_udf(left_values, right_values)
+                else:
+                    key_table = left_keys if left_keys.num_rows > 0 else right_keys
+                    key = tuple(c[0] for c in key_table.columns)
+                    result = cogrouped_udf(key, left_values, right_values)
+
+                verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+
+                for batch in result.to_batches():
+                    if reorder_names is not None:
+                        # Names and types already validated equal; pure reorder, no cast.
+                        batch = pa.RecordBatch.from_arrays(
+                            [batch.column(name) for name in reorder_names],
+                            names=reorder_names,
+                        )
+                    yield ArrowBatchTransformer.wrap_struct(batch)
+
+        # profiling is not supported for UDF
+        return cogrouped_func, None, ser, ser
+
     if (
         eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
         and not runner_conf.use_legacy_pandas_udf_conversion
@@ -3431,32 +3463,6 @@ def mapper(a):
             df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
             return f(df1_keys, df1_vals, df2_keys, df2_vals)
 
-    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
-        import pyarrow as pa
-
-        # We assume there is only one UDF here because cogrouped map doesn't
-        # support combining multiple UDFs.
-        assert num_udfs == 1
-        arg_offsets, f = udfs[0]
-
-        parsed_offsets = extract_key_value_indexes(arg_offsets)
-
-        def batch_from_offset(batch, offsets):
-            return pa.RecordBatch.from_arrays(
-                arrays=[batch.columns[o] for o in offsets],
-                names=[batch.schema.names[o] for o in offsets],
-            )
-
-        def table_from_batches(batches, offsets):
-            return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches])
-
-        def mapper(a):
-            df1_keys = table_from_batches(a[0], parsed_offsets[0][0])
-            df1_vals = table_from_batches(a[0], parsed_offsets[0][1])
-            df2_keys = table_from_batches(a[1], parsed_offsets[1][0])
-            df2_vals = table_from_batches(a[1], parsed_offsets[1][1])
-            return f(df1_keys, df1_vals, df2_keys, df2_vals)
-
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
         # We assume there is only one UDF here because grouped agg doesn't
         # support combining multiple UDFs.
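
Usage sketch: the rewritten worker path above backs the cogrouped applyInArrow API. The column names and toy data below are hypothetical; this assumes a live SparkSession and pyarrow >= 7 (for Table.join), and is illustrative rather than part of the patch itself.

    import pyarrow as pa

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    left = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["id", "v1"])
    right = spark.createDataFrame([(1, "x"), (2, "y")], ["id", "v2"])

    def merge(left_tbl: pa.Table, right_tbl: pa.Table) -> pa.Table:
        # Each cogroup arrives as two pyarrow.Tables; the returned Table must
        # match the declared schema exactly, per the strict contract enforced
        # by verify_arrow_table in the worker.
        return left_tbl.join(right_tbl, keys="id")

    left.groupBy("id").cogroup(right.groupBy("id")).applyInArrow(
        merge, schema="id long, v1 double, v2 string"
    ).show()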