# [SPARK-56477][PYTHON] Refactor SQL_GROUPED_MAP_PANDAS_UDF #55495
Changes from all commits: 595cf1e, 41dc496, 29e79c7, cbe29b6
```diff
@@ -652,33 +652,6 @@ def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
     verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    def wrapped(key_series, value_series):
-        import pandas as pd
-
-        value_df = pd.concat(value_series, axis=1)
-
-        if len(argspec.args) == 1:
-            result = f(value_df)
-        elif len(argspec.args) == 2:
-            # Extract key from pandas Series, preserving numpy types
-            key = tuple(s.iloc[0] for s in key_series)
-            result = f(key, value_df)
-
-        verify_pandas_result(
-            result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
-        )
-
-        yield result
-
-    def flatten_wrapper(k, v):
-        # Return Iterator[[(df, spark_type)]] directly
-        for df in wrapped(k, v):
-            yield [(df, return_type)]
-
-    return flatten_wrapper
-
-
 def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
     def wrapped(key_series, value_batches):
         import pandas as pd
```
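For context on the `len(argspec.args)` dispatch the deleted wrapper performed: grouped map UDFs reach the worker through the public `applyInPandas` API, which accepts either a one-argument or a two-argument function. A minimal sketch (data, column names, and schema strings are illustrative):

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # One-argument form: receives only the group's data.
    return pdf.assign(v=pdf.v - pdf.v.mean())

def tag_key(key, pdf: pd.DataFrame) -> pd.DataFrame:
    # Two-argument form: also receives the grouping key as a tuple.
    return pdf.assign(id=key[0])

df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
df.groupby("id").applyInPandas(tag_key, schema="id long, v double").show()
```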
```diff
@@ -1139,8 +1112,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
         return func, None, None, None
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
-        return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+        num_udf_args = len(inspect.getfullargspec(chained_func).args)
+        return func, args_offsets, return_type, num_udf_args
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_grouped_map_pandas_iter_udf(
```
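The branch now keeps only the function's arity rather than the full argspec. A standalone sketch of what it computes (`subtract_mean` and `with_key` are hypothetical UDFs):

```python
import inspect

def subtract_mean(pdf):
    return pdf

def with_key(key, pdf):
    return pdf

print(len(inspect.getfullargspec(subtract_mean).args))  # 1 -> data-only variant
print(len(inspect.getfullargspec(with_key).args))       # 2 -> key-aware variant
```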
```diff
@@ -2393,6 +2366,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
     ):
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
         if eval_type in (
+            PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
             PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
             PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
```
```diff
@@ -2413,10 +2387,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
             prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
             int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
         )
-    elif (
-        eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
-        or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
-    ):
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
         ser = GroupPandasUDFSerializer(
             timezone=runner_conf.timezone,
             safecheck=runner_conf.safecheck,
```
```diff
@@ -2943,6 +2914,77 @@ def grouped_func(
         # profiling is not supported for UDF
         return grouped_func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        import pyarrow as pa
+        import pandas as pd
+
+        assert num_udfs == 1, "One GROUPED_MAP_PANDAS UDF expected here."
+        grouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+        assert len(parsed_offsets) == 1, "Expected one pair of offsets for GROUPED_MAP_PANDAS UDF."
+
+        key_offsets = parsed_offsets[0][0]
+        value_offsets = parsed_offsets[0][1]
+        output_schema = StructType([StructField("_0", return_type)])
+
+        def grouped_func(
+            split_index: int,
+            data: Iterator[Iterator[pa.RecordBatch]],
+        ) -> Iterator[pa.RecordBatch]:
+            """Apply groupBy Pandas UDF (non-iterator variant).
+
+            The explicit ``del`` calls below keep peakmem bounded across
+            groups. Without them, generator locals from the previous
+            iteration stay bound on the frame until each statement in
+            the next iteration rebinds its slot, so the input-side
+            DataFrames overlap with the next group's allocations and
+            the working set grows unbounded on wide-column, large-group
+            inputs. ``del result`` runs on resume from yield, before
+            ``data.__next__()`` is asked for the next group.
+            """
```
> **Contributor**, on lines +2934 to +2944: The peakmem rationale is plausible (input-side …
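A toy, non-Spark illustration of the lifetime argument in that docstring (`transform` and the data are stand-ins): without the `del`s, the previous iteration's `big` stays referenced by the generator frame while the next group's `list(group)` is built, so two groups' inputs coexist at the peak.

```python
def transform(chunk):
    return sum(chunk)

def gen(groups):
    for group in groups:
        big = list(group)   # materialize this group's input
        result = transform(big)
        del big             # release the input before yielding
        yield result
        del result          # runs on resume from yield, before next(groups)

print(list(gen(iter([range(3), range(1000)]))))  # [3, 499500]
```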
```diff
+            for group in data:
+                all_batches = list(group)
+                if all_batches:
+                    table = pa.Table.from_batches(all_batches).combine_chunks()
+                else:
+                    table = pa.table({})
+                all_series = ArrowBatchTransformer.to_pandas(
+                    table,
+                    timezone=runner_conf.timezone,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                )
+                value_df = pd.concat([all_series[o] for o in value_offsets], axis=1)
+
+                if num_udf_args == 1:
+                    result = grouped_udf(value_df)
+                else:
+                    key = tuple(all_series[o].iloc[0] for o in key_offsets)
+                    result = grouped_udf(key, value_df)
+
+                del all_batches, table, all_series, value_df
+
+                verify_pandas_result(
+                    result,
+                    return_type,
+                    runner_conf.assign_cols_by_name,
+                    truncate_return_schema=False,
+                )
+
+                yield PandasToArrowConversion.convert(
+                    [result],
+                    output_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )
+                del result
+
+        # profiling is not supported for UDF
+        return grouped_func, None, ser, ser
+
     if (
         eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
         and not runner_conf.use_legacy_pandas_udf_conversion
```

> **Contributor**, on the `PandasToArrowConversion.convert` output path: The OLD output path went through …. Note also the resulting divergence with ….
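The Arrow-level merge used in the loop above, in isolation. A minimal pyarrow sketch with made-up data: `Table.from_batches` stitches one group's batches into a single table, and `combine_chunks` compacts each column into one contiguous chunk so the pandas conversion happens once per group.

```python
import pyarrow as pa

# Made-up batches standing in for one group's incoming stream.
batches = [
    pa.RecordBatch.from_pydict({"id": [1, 1], "v": [1.0, 2.0]}),
    pa.RecordBatch.from_pydict({"id": [1], "v": [3.0]}),
]
table = pa.Table.from_batches(batches).combine_chunks()
print(table.column("v").num_chunks)  # 1: one contiguous chunk per column
print(table.num_rows)                # 3
```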
```diff
@@ -3187,39 +3229,7 @@ def map_batch(batch):
         # profiling is not supported for UDF
         return func, None, ser, ser
 
-    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        import pyarrow as pa
-
-        # We assume there is only one UDF here because grouped map doesn't
-        # support combining multiple UDFs.
-        assert num_udfs == 1
-
-        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
-        # distinguish between grouping attributes and data attributes
-        arg_offsets, f = udfs[0]
-        parsed_offsets = extract_key_value_indexes(arg_offsets)
-
-        key_offsets = parsed_offsets[0][0]
-        value_offsets = parsed_offsets[0][1]
-
-        def mapper(batch_iter):
-            # Collect all Arrow batches and merge at Arrow level
-            all_batches = list(batch_iter)
-            if all_batches:
-                table = pa.Table.from_batches(all_batches).combine_chunks()
-            else:
-                table = pa.table({})
-            # Convert to pandas once for the entire group
-            all_series = ArrowBatchTransformer.to_pandas(
-                table,
-                timezone=ser._timezone,
-                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
-            )
-            key_series = [all_series[o] for o in key_offsets]
-            value_series = [all_series[o] for o in value_offsets]
-            yield from f(key_series, value_series)
-
-    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
         import pyarrow as pa
 
         # We assume there is only one UDF here because grouped map doesn't
```
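The key/value split the removed `mapper` performed, and that the new `grouped_func` inlines, with hypothetical offsets and toy series (the real offsets come from `extract_key_value_indexes(arg_offsets)`; see `FlatMapGroupsInPandasExec` per the deleted comment):

```python
import pandas as pd

# Hypothetical layout: column 0 is the grouping key, column 1 is data.
key_offsets, value_offsets = [0], [1]
all_series = [pd.Series([7, 7], name="id"), pd.Series([1.0, 2.0], name="v")]

key_series = [all_series[o] for o in key_offsets]
value_series = [all_series[o] for o in value_offsets]
key = tuple(s.iloc[0] for s in key_series)  # (7,): constant within a group
value_df = pd.concat(value_series, axis=1)  # the DataFrame handed to the UDF
print(key, value_df.columns.tolist())       # (7,) ['v']
```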
> Heads up: after this PR, `GroupPandasUDFSerializer` is only used by `SQL_GROUPED_MAP_PANDAS_ITER_UDF`. The class comment at `python/pyspark/sql/pandas/serializers.py:657` (`# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF`) is now stale. Worth updating in this PR or noting as a follow-up.