142 changes: 76 additions & 66 deletions python/pyspark/worker.py
@@ -652,33 +652,6 @@ def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)


def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
def wrapped(key_series, value_series):
import pandas as pd

value_df = pd.concat(value_series, axis=1)

if len(argspec.args) == 1:
result = f(value_df)
elif len(argspec.args) == 2:
# Extract key from pandas Series, preserving numpy types
key = tuple(s.iloc[0] for s in key_series)
result = f(key, value_df)

verify_pandas_result(
result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
)

yield result

def flatten_wrapper(k, v):
# Return Iterator[[(df, spark_type)]] directly
for df in wrapped(k, v):
yield [(df, return_type)]

return flatten_wrapper


def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
def wrapped(key_series, value_batches):
import pandas as pd
@@ -1139,8 +1112,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
return func, None, None, None
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
num_udf_args = len(inspect.getfullargspec(chained_func).args)
return func, args_offsets, return_type, num_udf_args
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_grouped_map_pandas_iter_udf(
@@ -2393,6 +2366,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
):
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
if eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
@@ -2413,10 +2387,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
elif (
eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
):
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
ser = GroupPandasUDFSerializer(
Contributor
Heads up — after this PR, GroupPandasUDFSerializer is only used by SQL_GROUPED_MAP_PANDAS_ITER_UDF. The class comment at python/pyspark/sql/pandas/serializers.py:657 (# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF) is now stale. Worth updating in this PR or noting as a follow-up.
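For reference, the one-line update would look something like this (the "before" text is quoted from the stale comment above; the "after" is only a suggestion):

    # python/pyspark/sql/pandas/serializers.py:657, before:
    # Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
    # after (suggested):
    # Serializer for SQL_GROUPED_MAP_PANDAS_ITER_UDF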

timezone=runner_conf.timezone,
safecheck=runner_conf.safecheck,
@@ -2943,6 +2914,77 @@ def grouped_func(
# profiling is not supported for UDF
return grouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
import pyarrow as pa
import pandas as pd

assert num_udfs == 1, "One GROUPED_MAP_PANDAS UDF expected here."
grouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
parsed_offsets = extract_key_value_indexes(arg_offsets)
assert len(parsed_offsets) == 1, "Expected one pair of offsets for GROUPED_MAP_PANDAS UDF."

key_offsets = parsed_offsets[0][0]
value_offsets = parsed_offsets[0][1]
output_schema = StructType([StructField("_0", return_type)])

def grouped_func(
split_index: int,
data: Iterator[Iterator[pa.RecordBatch]],
) -> Iterator[pa.RecordBatch]:
"""Apply groupBy Pandas UDF (non-iterator variant).

The explicit ``del`` calls below keep peakmem bounded across
groups. Without them, generator locals from the previous
iteration stay bound on the frame until each statement in
the next iteration rebinds its slot, so the input-side
DataFrames overlap with the next group's allocations and
        the working set roughly doubles on wide-column, large-group
inputs. ``del result`` runs on resume from yield, before
``data.__next__()`` is asked for the next group.
"""
Comment on lines +2934 to +2944
Contributor
The peakmem rationale is plausible (input-side del before the convert call lowers peak vs. a helper-function approach), but the PR description benchmarks are wall-clock only. Could you add a peakmem comparison (e.g., tracemalloc.get_traced_memory() peaks or ASV's peakmem_* benchmarks) for the wide-column / large-group scenarios this docstring describes? That would close the loop on why the inline-with-del form was preferred over the simpler helper-function variant from the first commit, and protect against future edits inadvertently dropping a del and regressing peakmem.
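A minimal sketch of what such a measurement could look like, using a hypothetical stand-in for grouped_func (group sizes and shapes are made up here, not taken from the PR benchmarks):

    import tracemalloc

    import numpy as np
    import pandas as pd

    def gen(n_groups, with_del):
        # Hypothetical stand-in for grouped_func: one large frame per group.
        for _ in range(n_groups):
            df = pd.DataFrame(np.ones((1_000_000, 8)))  # ~64 MB input side
            result = df * 2
            if with_del:
                del df  # drop the input-side frame before yielding
            yield result
            if with_del:
                del result  # runs on resume, before the next group is built

    for with_del in (False, True):
        tracemalloc.start()
        for batch in gen(5, with_del):
            del batch  # consumer drops its reference before pulling the next group
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        print(f"with_del={with_del}: peak={peak / 1e6:.0f} MB")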

for group in data:
all_batches = list(group)
if all_batches:
table = pa.Table.from_batches(all_batches).combine_chunks()
else:
table = pa.table({})
all_series = ArrowBatchTransformer.to_pandas(
table,
timezone=runner_conf.timezone,
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
)
value_df = pd.concat([all_series[o] for o in value_offsets], axis=1)

if num_udf_args == 1:
result = grouped_udf(value_df)
else:
key = tuple(all_series[o].iloc[0] for o in key_offsets)
result = grouped_udf(key, value_df)

del all_batches, table, all_series, value_df

verify_pandas_result(
result,
return_type,
runner_conf.assign_cols_by_name,
truncate_return_schema=False,
)

yield PandasToArrowConversion.convert(
[result],
output_schema,
timezone=runner_conf.timezone,
safecheck=runner_conf.safecheck,
arrow_cast=True,
prefers_large_types=runner_conf.use_large_var_types,
Contributor
The OLD output path went through GroupPandasUDFSerializer.__init__, which omits prefers_large_types when calling super().__init__() and so defaulted to False. The new code respects spark.sql.execution.arrow.useLargeVarTypes. Since ArrowUtils.fromArrowSchema maps both Utf8 and LargeUtf8 to StringType (sql/api/.../ArrowUtils.scala:82,84), this is wire-format-only and user-invisible at the Spark type level — but the PR description's "No" to user-facing change is no longer strictly correct. Either:

  • Note in the PR description that this aligns with the Arrow analogue (SPARK-55608) and is intentional, or
  • Pass prefers_large_types=False here to preserve the exact pre-PR wire format.

Note also the resulting divergence with SQL_GROUPED_MAP_PANDAS_ITER_UDF, which still uses GroupPandasUDFSerializer with the hardcoded False default (worker.py:2391-2397). Once the iter variant migrates as part of SPARK-55388, this divergence resolves.
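To illustrate the wire-format-only point with plain pyarrow (not tied to this PR's code paths):

    import pyarrow as pa

    small = pa.array(["spark", "arrow"], type=pa.string())        # Utf8: 32-bit offsets
    large = pa.array(["spark", "arrow"], type=pa.large_string())  # LargeUtf8: 64-bit offsets

    # Same logical values; only the offset-buffer width differs,
    # and Spark maps both to StringType on read.
    assert small.to_pylist() == large.to_pylist()
    assert small.type != large.type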

assign_cols_by_name=runner_conf.assign_cols_by_name,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
del result

# profiling is not supported for UDF
return grouped_func, None, ser, ser

if (
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
@@ -3187,39 +3229,7 @@ def map_batch(batch):
# profiling is not supported for UDF
return func, None, ser, ser

if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
import pyarrow as pa

# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1

# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, f = udfs[0]
parsed_offsets = extract_key_value_indexes(arg_offsets)

key_offsets = parsed_offsets[0][0]
value_offsets = parsed_offsets[0][1]

def mapper(batch_iter):
# Collect all Arrow batches and merge at Arrow level
all_batches = list(batch_iter)
if all_batches:
table = pa.Table.from_batches(all_batches).combine_chunks()
else:
table = pa.table({})
# Convert to pandas once for the entire group
all_series = ArrowBatchTransformer.to_pandas(
table,
timezone=ser._timezone,
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
)
key_series = [all_series[o] for o in key_offsets]
value_series = [all_series[o] for o in value_offsets]
yield from f(key_series, value_series)

elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
import pyarrow as pa

# We assume there is only one UDF here because grouped map doesn't