Skip to content

[SPARK-52307][PYTHON][CONNECT] Support Scalar Arrow Iterator UDF #51018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private[spark] object PythonEvalType {

// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
val SQL_SCALAR_ARROW_ITER_UDF = 251

val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
Expand Down Expand Up @@ -96,7 +97,10 @@ private[spark] object PythonEvalType {
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"

// Arrow UDFs
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
}
}

Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/connect/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def register(
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
]:
raise PySparkTypeError(
Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]

# Arrow UDFs
ArrowScalarUDFType = Literal[250]
ArrowScalarIterUDFType = Literal[251]

class ArrowVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...
Expand Down Expand Up @@ -135,6 +136,11 @@ ArrowScalarToScalarFunction = Union[
],
]

ArrowScalarIterFunction = Union[
Callable[[Iterable[pyarrow.Array]], Iterable[pyarrow.Array]],
Callable[[Tuple[pyarrow.Array, ...]], Iterable[pyarrow.Array]],
]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...

Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ArrowUDFType:

SCALAR = PythonEvalType.SQL_SCALAR_ARROW_UDF

SCALAR_ITER = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF


def arrow_udf(f=None, returnType=None, functionType=None):
return vectorized_udf(f, returnType, functionType, "arrow")
Expand Down Expand Up @@ -451,6 +453,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
)
if kind == "arrow" and eval_type not in [
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
None,
]: # None means it should infer the type from type hints.
raise PySparkTypeError(
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/pandas/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ from pyspark.sql.pandas._typing import (
PandasScalarUDFType,
ArrowScalarToScalarFunction,
ArrowScalarUDFType,
ArrowScalarIterFunction,
ArrowScalarIterUDFType,
)

from pyspark import since as since # noqa: F401
Expand All @@ -51,6 +53,7 @@ class PandasUDFType:

class ArrowUDFType:
SCALAR: ArrowScalarUDFType
SCALAR_ITER: ArrowScalarIterUDFType

@overload
def arrow_udf(
Expand All @@ -71,6 +74,24 @@ def arrow_udf(
*, returnType: DataTypeOrString, functionType: ArrowScalarUDFType
) -> Callable[[ArrowScalarToScalarFunction], UserDefinedFunctionLike]: ...
@overload
def arrow_udf(
f: ArrowScalarIterFunction,
returnType: Union[AtomicDataTypeOrString, ArrayType],
functionType: ArrowScalarIterUDFType,
) -> UserDefinedFunctionLike: ...
@overload
def arrow_udf(
f: Union[AtomicDataTypeOrString, ArrayType], returnType: ArrowScalarIterUDFType
) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
@overload
def arrow_udf(
*, returnType: Union[AtomicDataTypeOrString, ArrayType], functionType: ArrowScalarIterUDFType
) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
@overload
def arrow_udf(
f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: ArrowScalarIterUDFType
) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
@overload
def pandas_udf(
f: PandasScalarToScalarFunction,
returnType: Union[AtomicDataTypeOrString, ArrayType],
Expand Down
31 changes: 31 additions & 0 deletions python/pyspark/sql/pandas/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PandasScalarIterUDFType,
PandasGroupedAggUDFType,
ArrowScalarUDFType,
ArrowScalarIterUDFType,
)


Expand All @@ -36,6 +37,7 @@ def infer_eval_type(
"PandasScalarIterUDFType",
"PandasGroupedAggUDFType",
"ArrowScalarUDFType",
"ArrowScalarIterUDFType",
]:
"""
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
Expand Down Expand Up @@ -110,6 +112,21 @@ def infer_eval_type(
)
)

    # Iterator[Tuple[pa.Array, ...]] -> Iterator[pa.Array]
is_iterator_tuple_array = (
len(parameters_sig) == 1
and check_iterator_annotation( # Iterator
parameters_sig[0],
parameter_check_func=lambda a: check_tuple_annotation( # Tuple
a,
parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
),
)
and check_iterator_annotation(
return_annotation, parameter_check_func=lambda a: a == pa.Array
)
)

# Iterator[Series, Frame or Union[DataFrame, Series]] -> Iterator[Series or Frame]
is_iterator_series_or_frame = (
len(parameters_sig) == 1
Expand All @@ -128,6 +145,18 @@ def infer_eval_type(
)
)

    # Iterator[pa.Array] -> Iterator[pa.Array]
    # NOTE(review): the parameter check below also accepts a pd.Series annotation
    # (`a == pd.Series or a == pa.Array`), so Iterator[pd.Series] -> Iterator[pa.Array]
    # would match here too — confirm this is intended.
is_iterator_array = (
len(parameters_sig) == 1
and check_iterator_annotation(
parameters_sig[0],
parameter_check_func=lambda a: (a == pd.Series or a == pa.Array),
)
and check_iterator_annotation(
return_annotation, parameter_check_func=lambda a: a == pa.Array
)
)

# Series, Frame or Union[DataFrame, Series], ... -> Any
is_series_or_frame_agg = all(
a == pd.Series
Expand All @@ -152,6 +181,8 @@ def infer_eval_type(
return ArrowUDFType.SCALAR
elif is_iterator_tuple_series_or_frame or is_iterator_series_or_frame:
return PandasUDFType.SCALAR_ITER
elif is_iterator_tuple_array or is_iterator_array:
return ArrowUDFType.SCALAR_ITER
elif is_series_or_frame_agg:
return PandasUDFType.GROUPED_AGG
else:
Expand Down
Loading