Skip to content

Commit

Permalink
[SPARK-48142][PYTHON][CONNECT][TESTS] Enable `CogroupedApplyInPandasTests.test_wrong_args`
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
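Enable `CogroupedApplyInPandasTests.test_wrong_args` in Spark Connect by adding the missing pandas UDF validation to the Connect implementation of `applyInPandas`. The validation logic is extracted from `_create_pandas_udf` into a reusable `_validate_pandas_udf` helper, which the Connect `applyInPandas` now calls before building the UDF.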

### Why are the changes needed?
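To improve test coverage: `test_wrong_args` was previously skipped in the Spark Connect test suite because it failed there.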

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Verified by CI.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #46397 from zhengruifeng/fix_pandas_udf_check.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed May 7, 2024
1 parent 7290000 commit 2ef7246
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/pyspark/sql/connect/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ def applyInPandas(
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]

_validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
udf_obj = UserDefinedFunction(
func,
returnType=schema,
Expand Down
9 changes: 8 additions & 1 deletion python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
return _create_pandas_udf(f=f, returnType=return_type, evalType=eval_type)


def _create_pandas_udf(f, returnType, evalType):
# validate the pandas udf and return the adjusted eval type
def _validate_pandas_udf(f, returnType, evalType) -> int:
argspec = getfullargspec(f)

# pandas UDF by type hints.
Expand Down Expand Up @@ -528,6 +529,12 @@ def _create_pandas_udf(f, returnType, evalType):
},
)

return evalType


def _create_pandas_udf(f, returnType, evalType):
evalType = _validate_pandas_udf(f, returnType, evalType)

if is_remote():
from pyspark.sql.connect.udf import _create_udf as _create_connect_udf

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from pyspark.testing.connectutils import ReusedConnectTestCase


class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
@unittest.skip("Fails in Spark Connect, should enable.")
def test_wrong_args(self):
self.check_wrong_args()
class CogroupedApplyInPandasTests(
CogroupedApplyInPandasTestsMixin,
ReusedConnectTestCase,
):
pass


if __name__ == "__main__":
Expand Down

0 comments on commit 2ef7246

Please sign in to comment.