diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 699dce76c4a15..c916e8acf3e43 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -398,7 +398,9 @@ def applyInPandas( ) -> "DataFrame": from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] + _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF) udf_obj = UserDefinedFunction( func, returnType=schema, diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 62d365a3b2a1d..5922a5ced8639 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -431,7 +431,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: return _create_pandas_udf(f=f, returnType=return_type, evalType=eval_type) -def _create_pandas_udf(f, returnType, evalType): +# validate the pandas udf and return the adjusted eval type +def _validate_pandas_udf(f, returnType, evalType) -> int: argspec = getfullargspec(f) # pandas UDF by type hints. @@ -528,6 +529,12 @@ def _create_pandas_udf(f, returnType, evalType): }, ) + return evalType + + +def _create_pandas_udf(f, returnType, evalType): + evalType = _validate_pandas_udf(f, returnType, evalType) + if is_remote(): from pyspark.sql.connect.udf import _create_udf as _create_connect_udf diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py index 708960dd47d41..00d71bda2d938 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py @@ -20,10 +20,11 @@ from pyspark.testing.connectutils import ReusedConnectTestCase -class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedConnectTestCase): - @unittest.skip("Fails in Spark Connect, should enable.") - def test_wrong_args(self): - self.check_wrong_args() +class CogroupedApplyInPandasTests( + CogroupedApplyInPandasTestsMixin, + ReusedConnectTestCase, +): + pass if __name__ == "__main__":