From edb004caf5280232f391aa240b4b7c3722c9c309 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Wed, 29 Apr 2026 08:12:00 +0000
Subject: [PATCH 1/2] refactor: extract SQL_SCALAR_PANDAS_UDF processing into
 read_udfs

---
 python/pyspark/worker.py | 123 ++++++++++++++++++++++++++-------------
 1 file changed, 84 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 95a7ccdc4f8d..a2577841ee9c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -378,42 +378,6 @@ def wrap_udf(f, args_offsets, kwargs_offsets, return_type):
     return args_kwargs_offsets, lambda *a: func(*a)
 
 
-def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
-    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
-    def verify_result_length(result, length):
-        if len(result) != length:
-            raise PySparkRuntimeError(
-                errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
-                messageParameters={
-                    "udf_type": "pandas_udf",
-                    "expected": str(length),
-                    "actual": str(len(result)),
-                },
-            )
-        return result
-
-    return (
-        args_kwargs_offsets,
-        lambda *a: (
-            verify_result_length(verify_result_type(func(*a)), len(a[0])),
-            return_type,
-        ),
-    )
-
-
 def wrap_pandas_batch_iter_udf(f, return_type, runner_conf):
     iter_type_label = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
 
@@ -1125,7 +1089,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
         return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
@@ -2479,14 +2443,14 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
+        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
     ):
         ser = ArrowStreamSerializer(write_start_stream=True)
     else:
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
         df_for_struct = (
-            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
-            or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
             or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )
 
@@ -3122,6 +3086,87 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
         # profiling is not supported for UDF
         return func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+        import pandas as pd
+        import pyarrow as pa
+
+        # --- UDF preparation ---
+        udf_infos = []
+        for udf_func, udf_args_offsets, udf_kwargs_offsets, udf_return_type in udfs:
+            wrapped_func, args_kwargs_offsets = wrap_kwargs_support(
+                udf_func, udf_args_offsets, udf_kwargs_offsets
+            )
+            udf_infos.append((wrapped_func, args_kwargs_offsets, udf_return_type))
+        col_names = [f"_{i}" for i in range(len(udfs))]
+        return_schema = StructType(
+            [StructField(name, info[2]) for name, info in zip(col_names, udf_infos)]
+        )
+
+        def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            for input_batch in batches:
+                num_rows = input_batch.num_rows
+
+                # --- Input: Arrow -> pandas Series (struct columns become DataFrames) ---
+                pandas_columns = ArrowBatchTransformer.to_pandas(
+                    input_batch,
+                    timezone=runner_conf.timezone,
+                    struct_in_pandas="dict",
+                    ndarray_as_list=False,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                    df_for_struct=True,
+                )
+
+                # --- Process: evaluate each UDF column-wise on pandas Series ---
+                results = []
+                for udf_func, offsets, udf_return_type in udf_infos:
+                    result = udf_func(*[pandas_columns[o] for o in offsets])
+                    if not hasattr(result, "__len__"):
+                        pd_type = (
+                            "pandas.DataFrame"
+                            if isinstance(udf_return_type, StructType)
+                            else "pandas.Series"
+                        )
+                        raise PySparkTypeError(
+                            errorClass="UDF_RETURN_TYPE",
+                            messageParameters={
+                                "expected": pd_type,
+                                "actual": type(result).__name__,
+                            },
+                        )
+                    if len(result) != num_rows:
+                        raise PySparkRuntimeError(
+                            errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+                            messageParameters={
+                                "udf_type": "pandas_udf",
+                                "expected": str(num_rows),
+                                "actual": str(len(result)),
+                            },
+                        )
+                    # struct_in_pandas="dict": UDF must return DataFrame for struct types
+                    if isinstance(udf_return_type, StructType) and not isinstance(
+                        result, pd.DataFrame
+                    ):
+                        raise PySparkValueError(
+                            "Invalid return type. Please make sure that the UDF returns a "
+                            "pandas.DataFrame when the specified return type is StructType."
+                        )
+                    results.append(result)
+
+                # --- Output: pandas -> Arrow ---
+                yield PandasToArrowConversion.convert(
+                    results,
+                    return_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )
+
+        # profiling is not supported for UDF
+        return func, None, ser, ser
+
     is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
     is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
 

From ee38c52bc2ef4c56943cc5907b3b780e9d5b52c4 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Thu, 30 Apr 2026 20:45:52 +0000
Subject: [PATCH 2/2] fix: rename func param to data to match other variants
 for mypy

---
 python/pyspark/worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a2577841ee9c..65fdf6054fa4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -3102,8 +3102,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
             [StructField(name, info[2]) for name, info in zip(col_names, udf_infos)]
         )
 
-        def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
-            for input_batch in batches:
+        def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            for input_batch in data:
                 num_rows = input_batch.num_rows
 
                 # --- Input: Arrow -> pandas Series (struct columns become DataFrames) ---
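Reviewer note (illustrative, not part of the patches): the refactor routes
SQL_SCALAR_PANDAS_UDF through the Arrow stream serializer and enforces three
per-batch checks inline: the result must have __len__, must match the input
batch length, and must be a pandas.DataFrame when the return type is a
StructType. Below is a minimal sketch of user code that would exercise this
path; the DataFrame, column, and function names are made up for the example,
while pandas_udf and the type-hint style are the actual PySpark API.

    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(4).toDF("x")

    @pandas_udf("long")
    def plus_one(x: pd.Series) -> pd.Series:
        # Must return a Series of the same length as the input batch; a
        # length mismatch trips the SCHEMA_MISMATCH_FOR_PANDAS_UDF check.
        return x + 1

    @pandas_udf("a long, b long")
    def as_struct(x: pd.Series) -> pd.DataFrame:
        # A StructType return must be a pandas.DataFrame (struct_in_pandas=
        # "dict"); returning a Series here would raise the PySparkValueError
        # added in patch 1.
        return pd.DataFrame({"a": x, "b": x * 2})

    df.select(plus_one("x"), as_struct("x")).show()

Both UDFs arrive in the worker as entries in udfs, are wrapped by
wrap_kwargs_support (which also covers keyword-argument calls via
kwargs_offsets), and are then evaluated column-wise per Arrow batch as in the
added read_udfs block.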