diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 32ea541711848..2e38838b10288 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -62,7 +62,6 @@ ArrowStreamGroupSerializer, ArrowStreamPandasUDFSerializer, ArrowStreamPandasUDTFSerializer, - GroupPandasUDFSerializer, CogroupArrowUDFSerializer, CogroupPandasUDFSerializer, ApplyInPandasWithStateSerializer, @@ -652,39 +651,6 @@ def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types): verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types) -def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf): - def wrapped(key_series, value_batches): - import pandas as pd - - # value_batches is an Iterator[list[pd.Series]] (one list per batch) - # Convert each list of Series into a DataFrame - def dataframe_iter(): - for value_series in value_batches: - yield pd.concat(value_series, axis=1) - - if len(argspec.args) == 1: - result = f(dataframe_iter()) - elif len(argspec.args) == 2: - # Extract key from pandas Series, preserving numpy types - key = tuple(s.iloc[0] for s in key_series) - result = f(key, dataframe_iter()) - - def verify_element(df): - verify_pandas_result( - df, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False - ) - return df - - yield from map(verify_element, result) - - def flatten_wrapper(k, v): - # Return Iterator[[(df, spark_type)]] directly - for df in wrapped(k, v): - yield [(df, return_type)] - - return flatten_wrapper - - def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): def wrapped(stateful_processor_api_client, mode, key, value_series_gen): result_iter = f(stateful_processor_api_client, mode, key, value_series_gen) @@ -1122,14 +1088,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: return func, None, None, None - elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, + ): num_udf_args = len(inspect.getfullargspec(chained_func).args) return func, args_offsets, return_type, num_udf_args - elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF: - argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it - return args_offsets, wrap_grouped_map_pandas_iter_udf( - func, return_type, argspec, runner_conf - ) elif eval_type in ( PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, @@ -2378,6 +2342,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf): # NOTE: if timezone is set here, that implies respectSessionTimeZone is True if eval_type in ( PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF, @@ -2398,14 +2363,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf): prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) - elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF: - ser = GroupPandasUDFSerializer( - timezone=runner_conf.timezone, - safecheck=runner_conf.safecheck, - assign_cols_by_name=runner_conf.assign_cols_by_name, - prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, - int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, - ) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: ser = CogroupArrowUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: @@ -2996,6 +2953,82 @@ def grouped_func( # profiling is not supported for UDF return grouped_func, None, ser, ser + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF: + import pyarrow as pa + import pandas as pd + + assert num_udfs == 1, "One GROUPED_MAP_PANDAS_ITER UDF expected here." + grouped_udf, arg_offsets, return_type, num_udf_args = udfs[0] + parsed_offsets = extract_key_value_indexes(arg_offsets) + assert len(parsed_offsets) == 1, ( + "Expected one pair of offsets for GROUPED_MAP_PANDAS_ITER UDF." + ) + + key_offsets = parsed_offsets[0][0] + value_offsets = parsed_offsets[0][1] + output_schema = StructType([StructField("_0", return_type)]) + + def grouped_func( + split_index: int, + data: Iterator[Iterator[pa.RecordBatch]], + ) -> Iterator[pa.RecordBatch]: + """Apply groupBy Pandas UDF (iterator variant). + + The UDF receives an Iterator[pd.DataFrame] per group and + returns an Iterator[pd.DataFrame]. Input batches are + converted to pandas lazily so peakmem stays bounded by a + single batch rather than the whole group. + """ + for group in data: + group_iter = iter(group) + # Read the first batch to extract grouping keys. + first_series = ArrowBatchTransformer.to_pandas( + next(group_iter), + timezone=runner_conf.timezone, + prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, + ) + + def dataframe_iter(): + yield pd.concat([first_series[o] for o in value_offsets], axis=1) + for batch in group_iter: + series = ArrowBatchTransformer.to_pandas( + batch, + timezone=runner_conf.timezone, + prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, + ) + yield pd.concat([series[o] for o in value_offsets], axis=1) + + if num_udf_args == 1: + result = grouped_udf(dataframe_iter()) + else: + key = tuple(first_series[o].iloc[0] for o in key_offsets) + result = grouped_udf(key, dataframe_iter()) + + for df in result: + verify_pandas_result( + df, + return_type, + runner_conf.assign_cols_by_name, + truncate_return_schema=False, + ) + yield PandasToArrowConversion.convert( + [df], + output_schema, + timezone=runner_conf.timezone, + safecheck=runner_conf.safecheck, + arrow_cast=True, + prefers_large_types=runner_conf.use_large_var_types, + assign_cols_by_name=runner_conf.assign_cols_by_name, + int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, + ) + + # Drain remaining input batches to maintain stream position. + for _ in group_iter: + pass + + # profiling is not supported for UDF + return grouped_func, None, ser, ser + if ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not runner_conf.use_legacy_pandas_udf_conversion @@ -3240,41 +3273,7 @@ def map_batch(batch): # profiling is not supported for UDF return func, None, ser, ser - if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF: - import pyarrow as pa - - # We assume there is only one UDF here because grouped map doesn't - # support combining multiple UDFs. - assert num_udfs == 1 - - # See FlatMapGroupsInPandasExec for how arg_offsets are used to - # distinguish between grouping attributes and data attributes - arg_offsets, f = udfs[0] - parsed_offsets = extract_key_value_indexes(arg_offsets) - - def mapper(batch_iter): - # Convert first Arrow batch to pandas to extract keys - first_series = ArrowBatchTransformer.to_pandas( - next(batch_iter), - timezone=ser._timezone, - prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, - ) - key_series = [first_series[o] for o in parsed_offsets[0][0]] - - # Lazily convert remaining Arrow batches to pandas Series - def value_series_gen(): - yield [first_series[o] for o in parsed_offsets[0][1]] - for batch in batch_iter: - series = ArrowBatchTransformer.to_pandas( - batch, - timezone=ser._timezone, - prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, - ) - yield [series[o] for o in parsed_offsets[0][1]] - - yield from f(key_series, value_series_gen()) - - elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1