[SPARK-48228][PYTHON][CONNECT] Implement the missing function validation in ApplyInXXX

### What changes were proposed in this pull request?
Implement the missing function validation in the `ApplyInXXX` methods (`applyInPandas`, `applyInPandasWithState`, and `applyInArrow`, on both grouped and cogrouped data) in Spark Connect.

apache#46397 fixed this issue for `Cogrouped.ApplyInPandas`; this PR fixes the remaining methods.

### Why are the changes needed?
For a better error message. For example, pass a function that takes no arguments:

```
In [12]: df1 = spark.range(11)

In [13]: df2 = df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

In [14]: df2.show()
```

Before this PR, the invalid function caused a confusing error at execution time:
```
24/05/10 11:37:36 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 36)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1834, in main
    process()
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1826, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 531, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 104, in dump_stream
    for batch in iterator:
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 524, in init_stream_yield_batches
    for series in iterator:
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 1610, in mapper
    return f(keys, vals)
           ^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 488, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
                          ^^^^^^^^^^^^^
  File "/Users/ruifeng.zheng/Dev/spark/python/lib/pyspark.zip/pyspark/worker.py", line 483, in wrapped
    result, return_type, _assign_cols_by_name, truncate_return_schema=False
    ^^^^^^
UnboundLocalError: cannot access local variable 'result' where it is not associated with a value

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:523)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:479)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
	at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:896)

	...
```

After this PR, the error is raised before execution, which is consistent with Spark Classic and much clearer:
```
PySparkValueError: [INVALID_PANDAS_UDF] Invalid function: pandas_udf with function type GROUPED_MAP or the function in groupby.applyInPandas must take either one argument (data) or two arguments (key, data).

```
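
For reference, here is a minimal sketch of functions that pass the new validation. Per the error message above, the function given to `groupby(...).applyInPandas` must take either one argument (the group as a pandas DataFrame) or two arguments (the grouping key and the DataFrame); the schema and computed values below are illustrative only:

```
import pandas as pd
from pyspark.sql.types import DoubleType, StructField, StructType

df1 = spark.range(11)
schema = StructType([StructField("d", DoubleType())])

# One-argument form: receives each group as a pandas DataFrame.
def mean_per_group(pdf: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({"d": [float(pdf["id"].mean())]})

df1.groupby("id").applyInPandas(mean_per_group, schema).show()

# Two-argument form: additionally receives the grouping key as a tuple.
def key_as_double(key, pdf: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({"d": [float(key[0])]})

df1.groupby("id").applyInPandas(key_as_double, schema).show()
```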

### Does this PR introduce _any_ user-facing change?
Yes, the error message changes.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46519 from zhengruifeng/missing_check_in_group.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
zhengruifeng authored and JacobZheng0927 committed May 11, 2024
1 parent 9fa10df commit c5da268
Showing 3 changed files with 28 additions and 4 deletions.
8 changes: 6 additions & 2 deletions python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@
 from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
+from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]
 from pyspark.sql.types import NumericType
 from pyspark.sql.types import StructType

@@ -293,6 +294,7 @@ def applyInPandas(
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame

+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -322,6 +324,7 @@ def applyInPandasWithState(
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame

+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE)
         udf_obj = UserDefinedFunction(
             func,
             returnType=outputStructType,
@@ -360,6 +363,7 @@ def applyInArrow(
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame

+        _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -398,9 +402,8 @@ def applyInPandas(
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]

-        _validate_pandas_udf(func, schema, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -426,6 +429,7 @@ def applyInArrow(
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame

+        _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF)
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/functions.py
@@ -432,7 +432,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:


 # validate the pandas udf and return the adjusted eval type
-def _validate_pandas_udf(f, returnType, evalType) -> int:
+def _validate_pandas_udf(f, evalType) -> int:
     argspec = getfullargspec(f)

     # pandas UDF by type hints.
@@ -533,7 +533,7 @@ def _validate_pandas_udf(f, returnType, evalType) -> int:


 def _create_pandas_udf(f, returnType, evalType):
-    evalType = _validate_pandas_udf(f, returnType, evalType)
+    evalType = _validate_pandas_udf(f, evalType)

     if is_remote():
         from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
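
For context (not part of the diff), the arity check `_validate_pandas_udf` applies to grouped-map functions boils down to something like the simplified sketch below; `_check_grouped_map_arity` is a hypothetical name, and the real implementation in `python/pyspark/sql/pandas/functions.py` handles many more eval types, type-hint styles, and error classes:

```
from inspect import getfullargspec

from pyspark.errors import PySparkValueError
from pyspark.util import PythonEvalType


def _check_grouped_map_arity(f, evalType):
    # Hypothetical helper: a grouped-map function must take either
    # one argument (data) or two arguments (key, data).
    argspec = getfullargspec(f)
    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) not in (1, 2):
        # The real code raises this with the INVALID_PANDAS_UDF error class.
        raise PySparkValueError(
            "Invalid function: the function in groupby.applyInPandas must take "
            "either one argument (data) or two arguments (key, data)."
        )
```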
20 changes: 20 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -439,6 +439,26 @@ def check_wrong_args(self):
                 pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
             )

+    def test_wrong_args_in_apply_func(self):
+        df1 = self.spark.range(11)
+        df2 = self.spark.range(22)
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())]))
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
+        with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
+            df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
+
     def test_unsupported_types(self):
         with self.quiet():
             self.check_unsupported_types()
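
To run the new tests locally, the standard PySpark test runner should work from the repository root (module path per the diff above):

```
python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_grouped_map'
```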
