123 changes: 84 additions & 39 deletions python/pyspark/worker.py
@@ -378,42 +378,6 @@ def wrap_udf(f, args_offsets, kwargs_offsets, return_type):
     return args_kwargs_offsets, lambda *a: func(*a)
 
 
-def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
-    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
-    def verify_result_length(result, length):
-        if len(result) != length:
-            raise PySparkRuntimeError(
-                errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
-                messageParameters={
-                    "udf_type": "pandas_udf",
-                    "expected": str(length),
-                    "actual": str(len(result)),
-                },
-            )
-        return result
-
-    return (
-        args_kwargs_offsets,
-        lambda *a: (
-            verify_result_length(verify_result_type(func(*a)), len(a[0])),
-            return_type,
-        ),
-    )
-
-
 def wrap_pandas_batch_iter_udf(f, return_type, runner_conf):
     iter_type_label = "pandas.DataFrame" if isinstance(return_type, StructType) else "pandas.Series"
 
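The deleted wrapper enforced two invariants on every scalar pandas UDF result: it must be sized (have __len__) and its length must match the input batch. Those checks survive, inlined into read_udfs in the last hunk below. A minimal sketch of that contract, with illustrative names (toy_verify, double) that are not part of worker.py:

    # Sketch of the contract the removed wrapper enforced: a scalar pandas UDF
    # must return a sized, batch-length result. Names here are illustrative.
    import pandas as pd

    def toy_verify(result, expected_len):
        if not hasattr(result, "__len__"):      # e.g. the UDF returned a scalar
            raise TypeError(f"expected pandas.Series, got {type(result).__name__}")
        if len(result) != expected_len:         # e.g. the UDF dropped rows
            raise RuntimeError(f"expected {expected_len} rows, got {len(result)}")
        return result

    def double(s: pd.Series) -> pd.Series:
        return s * 2

    batch = pd.Series([1, 2, 3])
    assert toy_verify(double(batch), len(batch)).tolist() == [2, 4, 6]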
@@ -1125,7 +1089,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
         return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
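read_single_udf now hands back the raw (func, args_offsets, kwargs_offsets, return_type) tuple for scalar pandas UDFs, deferring kwargs handling to read_udfs. Judging only from its call sites in this diff, wrap_kwargs_support appears to fold keyword-argument offsets into a single positional offset list; the toy model below illustrates that idea. toy_wrap_kwargs_support is hypothetical, not the real helper:

    # Toy model of the deferred-wrapping idea; an assumption based on the call
    # sites in this diff, not a copy of the real wrap_kwargs_support.
    def toy_wrap_kwargs_support(f, args_offsets, kwargs_offsets):
        keys = list(kwargs_offsets.keys())
        merged_offsets = args_offsets + [kwargs_offsets[k] for k in keys]

        def wrapped(*cols):
            # first len(args_offsets) columns are positional, the rest keyword
            n = len(args_offsets)
            return f(*cols[:n], **dict(zip(keys, cols[n:])))

        return wrapped, merged_offsets

    f = lambda a, *, scale: a * scale
    wrapped, offsets = toy_wrap_kwargs_support(f, [0], {"scale": 2})
    columns = [10, None, 3]                  # stand-ins for pandas Series
    assert wrapped(*[columns[o] for o in offsets]) == 30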
@@ -2479,14 +2443,14 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
+        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
     ):
         ser = ArrowStreamSerializer(write_start_stream=True)
     else:
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
         df_for_struct = (
-            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
-            or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
             or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )
 
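Adding SQL_SCALAR_PANDAS_UDF to this branch means the worker now exchanges plain Arrow IPC record batches with the JVM (the serializer writes the start-of-stream schema up front) and performs the pandas conversion itself, rather than going through the pandas-aware serializer in the else branch. A sketch of that wire format using only public pyarrow APIs, as an illustration rather than worker.py code:

    # What a plain Arrow stream serializer exchanges: record batches over the
    # Arrow IPC streaming format. Public pyarrow APIs only.
    import io
    import pyarrow as pa

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["_0"])

    sink = io.BytesIO()
    with pa.ipc.new_stream(sink, batch.schema) as writer:  # schema written first
        writer.write_batch(batch)

    with pa.ipc.open_stream(io.BytesIO(sink.getvalue())) as reader:
        for received in reader:                            # yields pa.RecordBatch
            assert received.num_rows == 3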
@@ -3122,6 +3086,87 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
         # profiling is not supported for UDF
         return func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+        import pandas as pd
+        import pyarrow as pa
+
+        # --- UDF preparation ---
+        udf_infos = []
+        for udf_func, udf_args_offsets, udf_kwargs_offsets, udf_return_type in udfs:
+            wrapped_func, args_kwargs_offsets = wrap_kwargs_support(
+                udf_func, udf_args_offsets, udf_kwargs_offsets
+            )
+            udf_infos.append((wrapped_func, args_kwargs_offsets, udf_return_type))
+        col_names = [f"_{i}" for i in range(len(udfs))]
+        return_schema = StructType(
+            [StructField(name, info[2]) for name, info in zip(col_names, udf_infos)]
+        )
+
+        def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            for input_batch in data:
+                num_rows = input_batch.num_rows
+
+                # --- Input: Arrow -> pandas Series (struct columns become DataFrames) ---
+                pandas_columns = ArrowBatchTransformer.to_pandas(
+                    input_batch,
+                    timezone=runner_conf.timezone,
+                    struct_in_pandas="dict",
+                    ndarray_as_list=False,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                    df_for_struct=True,
+                )
+
+                # --- Process: evaluate each UDF column-wise on pandas Series ---
+                results = []
+                for udf_func, offsets, udf_return_type in udf_infos:
+                    result = udf_func(*[pandas_columns[o] for o in offsets])
+                    if not hasattr(result, "__len__"):
+                        pd_type = (
+                            "pandas.DataFrame"
+                            if isinstance(udf_return_type, StructType)
+                            else "pandas.Series"
+                        )
+                        raise PySparkTypeError(
+                            errorClass="UDF_RETURN_TYPE",
+                            messageParameters={
+                                "expected": pd_type,
+                                "actual": type(result).__name__,
+                            },
+                        )
+                    if len(result) != num_rows:
+                        raise PySparkRuntimeError(
+                            errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+                            messageParameters={
+                                "udf_type": "pandas_udf",
+                                "expected": str(num_rows),
+                                "actual": str(len(result)),
+                            },
+                        )
+                    # struct_in_pandas="dict": UDF must return DataFrame for struct types
+                    if isinstance(udf_return_type, StructType) and not isinstance(
+                        result, pd.DataFrame
+                    ):
+                        raise PySparkValueError(
+                            "Invalid return type. Please make sure that the UDF returns a "
+                            "pandas.DataFrame when the specified return type is StructType."
+                        )
+                    results.append(result)
+
+                # --- Output: pandas -> Arrow ---
+                yield PandasToArrowConversion.convert(
+                    results,
+                    return_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )
+
+        # profiling is not supported for UDF
+        return func, None, ser, ser
+
     is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
     is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
 
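ArrowBatchTransformer.to_pandas and PandasToArrowConversion.convert are internal helpers, so the sketch below approximates the new per-batch flow with public pyarrow and pandas APIs only, skipping the timezone, safe-cast, and struct handling the real code delegates to them. double_udf and the single output column are illustrative:

    # Approximation of the per-batch flow: Arrow in, pandas evaluation,
    # length check, Arrow out. Public APIs only; not worker.py code.
    from typing import Iterator
    import pandas as pd
    import pyarrow as pa

    def double_udf(s: pd.Series) -> pd.Series:
        return s * 2

    def eval_batches(data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        for input_batch in data:
            num_rows = input_batch.num_rows
            # Arrow -> pandas: one Series per input column
            series = [input_batch.column(i).to_pandas() for i in range(input_batch.num_columns)]
            # Evaluate the UDF column-wise and enforce the batch-length contract
            result = double_udf(series[0])
            assert len(result) == num_rows
            # pandas -> Arrow: one output column per UDF
            yield pa.RecordBatch.from_arrays([pa.Array.from_pandas(result)], names=["_0"])

    batches = [pa.record_batch([pa.array([1, 2, 3])], names=["in"])]
    out = next(eval_batches(iter(batches)))
    assert out.column(0).to_pylist() == [2, 4, 6]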