Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 82 additions & 83 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
ArrowStreamGroupSerializer,
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
GroupPandasUDFSerializer,
CogroupArrowUDFSerializer,
CogroupPandasUDFSerializer,
ApplyInPandasWithStateSerializer,
Expand Down Expand Up @@ -652,39 +651,6 @@ def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)


def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
    """Adapt an iterator-style grouped-map pandas UDF to the worker calling convention.

    Parameters
    ----------
    f : callable
        The user function. It is called as ``f(batches)`` when it declares one
        positional argument and as ``f(key, batches)`` when it declares two,
        where ``batches`` is a lazy iterator of ``pandas.DataFrame``. Its return
        value is iterated and every element is verified as a pandas result.
    return_type
        Return type used to verify each produced DataFrame; it is also paired
        with each DataFrame in the wrapper's output.
    argspec
        Argument spec of the original user function (the signature is lost once
        the function is wrapped). Only ``argspec.args`` is read.
    runner_conf
        Configuration object; only ``assign_cols_by_name`` is read here.

    Returns
    -------
    callable
        A generator function ``flatten_wrapper(k, v)`` that yields one
        ``[(df, return_type)]`` list per DataFrame produced by ``f``.

    Raises
    ------
    TypeError
        Raised lazily (when the wrapper's output is first iterated) if the user
        function declares neither one nor two positional arguments.
    """
    # The arity never changes between groups, so compute it once.
    num_args = len(argspec.args)

    def wrapped(key_series, value_batches):
        import pandas as pd

        # Previously an unsupported arity left ``result`` unbound and surfaced
        # as an UnboundLocalError; fail with an explicit message instead.
        if num_args not in (1, 2):
            raise TypeError(
                "Grouped map pandas iterator UDF must take either one argument "
                "(batches) or two arguments (key, batches), but it takes "
                f"{num_args}."
            )

        # value_batches is an Iterator[list[pd.Series]] (one list per batch).
        # Each list of Series is stitched column-wise into a DataFrame on demand.
        def dataframe_iter():
            for value_series in value_batches:
                yield pd.concat(value_series, axis=1)

        if num_args == 1:
            result = f(dataframe_iter())
        else:
            # Extract key from pandas Series, preserving numpy types
            key = tuple(s.iloc[0] for s in key_series)
            result = f(key, dataframe_iter())

        for df in result:
            verify_pandas_result(
                df, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
            )
            yield df

    def flatten_wrapper(k, v):
        # Return Iterator[[(df, spark_type)]] directly
        for df in wrapped(k, v):
            yield [(df, return_type)]

    return flatten_wrapper


def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
result_iter = f(stateful_processor_api_client, mode, key, value_series_gen)
Expand Down Expand Up @@ -1122,14 +1088,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
return func, None, None, None
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
elif eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
):
num_udf_args = len(inspect.getfullargspec(chained_func).args)
return func, args_offsets, return_type, num_udf_args
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_grouped_map_pandas_iter_udf(
func, return_type, argspec, runner_conf
)
elif eval_type in (
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
Expand Down Expand Up @@ -2378,6 +2342,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
if eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
Expand All @@ -2398,14 +2363,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
ser = GroupPandasUDFSerializer(
timezone=runner_conf.timezone,
safecheck=runner_conf.safecheck,
assign_cols_by_name=runner_conf.assign_cols_by_name,
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = CogroupArrowUDFSerializer(assign_cols_by_name=runner_conf.assign_cols_by_name)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
Expand Down Expand Up @@ -2996,6 +2953,82 @@ def grouped_func(
# profiling is not supported for UDF
return grouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
import pyarrow as pa
import pandas as pd

assert num_udfs == 1, "One GROUPED_MAP_PANDAS_ITER UDF expected here."
grouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
parsed_offsets = extract_key_value_indexes(arg_offsets)
assert len(parsed_offsets) == 1, (
"Expected one pair of offsets for GROUPED_MAP_PANDAS_ITER UDF."
)

key_offsets = parsed_offsets[0][0]
value_offsets = parsed_offsets[0][1]
output_schema = StructType([StructField("_0", return_type)])

def grouped_func(
    split_index: int,
    data: Iterator[Iterator[pa.RecordBatch]],
) -> Iterator[pa.RecordBatch]:
    """Run the iterator-variant groupBy pandas UDF over every group.

    For each group the UDF is handed a lazy ``Iterator[pd.DataFrame]``
    (preceded by the grouping key when it takes two arguments) and is
    expected to return an iterator of ``pd.DataFrame``. Arrow batches
    are turned into pandas objects only as the UDF pulls them, so a
    whole group is never materialized up front. Each returned DataFrame
    is verified and converted back to an Arrow record batch.
    """

    def batch_to_series(arrow_batch):
        # Single place for the Arrow -> pandas conversion settings.
        return ArrowBatchTransformer.to_pandas(
            arrow_batch,
            timezone=runner_conf.timezone,
            prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
        )

    def value_frame(columns):
        # Keep only the value columns, side by side, as one DataFrame.
        return pd.concat([columns[o] for o in value_offsets], axis=1)

    for group in data:
        remaining = iter(group)
        # The first batch is converted eagerly: the grouping key is read from it.
        head = batch_to_series(next(remaining))

        def frames():
            yield value_frame(head)
            for arrow_batch in remaining:
                yield value_frame(batch_to_series(arrow_batch))

        if num_udf_args == 1:
            udf_args = (frames(),)
        else:
            group_key = tuple(head[o].iloc[0] for o in key_offsets)
            udf_args = (group_key, frames())

        for out_df in grouped_udf(*udf_args):
            verify_pandas_result(
                out_df,
                return_type,
                runner_conf.assign_cols_by_name,
                truncate_return_schema=False,
            )
            yield PandasToArrowConversion.convert(
                [out_df],
                output_schema,
                timezone=runner_conf.timezone,
                safecheck=runner_conf.safecheck,
                arrow_cast=True,
                prefers_large_types=runner_conf.use_large_var_types,
                assign_cols_by_name=runner_conf.assign_cols_by_name,
                int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
            )

        # The UDF may stop early; consume whatever it left unread so the
        # input stream is positioned at the start of the next group.
        for _ in remaining:
            pass

# profiling is not supported for UDF
return grouped_func, None, ser, ser

if (
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
Expand Down Expand Up @@ -3240,41 +3273,7 @@ def map_batch(batch):
# profiling is not supported for UDF
return func, None, ser, ser

if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
import pyarrow as pa

# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1

# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, f = udfs[0]
parsed_offsets = extract_key_value_indexes(arg_offsets)

def mapper(batch_iter):
    """Feed one group's Arrow batches to the wrapped UDF ``f`` as pandas Series.

    ``f`` receives the key Series taken from the first batch and a lazy
    generator that yields, per batch, the list of value Series. Whatever
    ``f`` produces is yielded through unchanged.
    """

    def as_series(arrow_batch):
        # Shared Arrow -> pandas conversion settings for every batch.
        return ArrowBatchTransformer.to_pandas(
            arrow_batch,
            timezone=ser._timezone,
            prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
        )

    # The first batch is converted eagerly because the keys are read from it.
    head = as_series(next(batch_iter))
    keys = [head[idx] for idx in parsed_offsets[0][0]]

    def value_batches():
        # Later batches are converted only when the consumer asks for them.
        yield [head[idx] for idx in parsed_offsets[0][1]]
        for arrow_batch in batch_iter:
            converted = as_series(arrow_batch)
            yield [converted[idx] for idx in parsed_offsets[0][1]]

    yield from f(keys, value_batches())

elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
if eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
Expand Down