Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@
"<arg_list> should not be set together."
]
},
"CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF": {
"message": [
"returnType can not be specified when `<arg_name>` is a user-defined function, but got <return_type>."
]
},
"COLUMN_IN_LIST": {
"message": [
"`<func_name>` does not allow a Column in a list."
Expand All @@ -89,11 +94,21 @@
"All items in `<arg_name>` should be in <allowed_types>, got <item_type>."
]
},
"INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
"message": [
"Pandas UDF should return StructType for <eval_type>, got <return_type>."
]
},
"INVALID_TIMEOUT_TIMESTAMP" : {
"message" : [
"Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
]
},
"INVALID_UDF_EVAL_TYPE" : {
"message" : [
"Eval type for UDF must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
]
},
"INVALID_WHEN_USAGE": {
"message": [
"when() can only be applied on a Column previously generated by when() function, and cannot be applied once otherwise() is applied."
Expand Down
33 changes: 19 additions & 14 deletions python/pyspark/sql/connect/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pyspark.sql.connect.types import UnparsedDataType
from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
from pyspark.errors import PySparkTypeError


if TYPE_CHECKING:
Expand Down Expand Up @@ -125,20 +126,24 @@ def __init__(
deterministic: bool = True,
):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
"{0}".format(type(func))
raise PySparkTypeError(
error_class="NOT_CALLABLE",
message_parameters={"arg_name": "func", "arg_type": type(func).__name__},
)

if not isinstance(returnType, (DataType, str)):
raise TypeError(
"Invalid return type: returnType should be DataType or str "
"but is {}".format(returnType)
raise PySparkTypeError(
error_class="NOT_DATATYPE_OR_STR",
message_parameters={
"arg_name": "returnType",
"arg_type": type(returnType).__name__,
},
)

if not isinstance(evalType, int):
raise TypeError(
"Invalid evaluation type: evalType should be an int but is {}".format(evalType)
raise PySparkTypeError(
error_class="NOT_INT",
message_parameters={"arg_name": "evalType", "arg_type": type(evalType).__name__},
)

self.func = func
Expand Down Expand Up @@ -241,9 +246,9 @@ def register(
# Python function.
if hasattr(f, "asNondeterministic"):
if returnType is not None:
raise TypeError(
"Invalid return type: data type can not be specified when f is"
"a user-defined function, but got %s." % returnType
raise PySparkTypeError(
error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
message_parameters={"arg_name": "f", "return_type": str(returnType)},
)
f = cast("UserDefinedFunctionLike", f)
if f.evalType not in [
Expand All @@ -252,9 +257,9 @@ def register(
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
]:
raise ValueError(
"Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
raise PySparkTypeError(
error_class="INVALID_UDF_EVAL_TYPE",
message_parameters={},
)
return_udf = f
self.sparkSession._client.register_udf(
Expand Down
13 changes: 8 additions & 5 deletions python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
NullType,
TimestampType,
)
from pyspark.errors import PythonException
from pyspark.errors import PythonException, PySparkTypeError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
Expand Down Expand Up @@ -212,12 +212,15 @@ def test_array_type_correct(self):
def test_register_grouped_map_udf(self):
    """Registering a GROUPED_MAP pandas UDF via the catalog must be rejected
    with the INVALID_UDF_EVAL_TYPE error class (only batched / scalar pandas /
    grouped-agg eval types are registrable)."""
    grouped_map_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
    # Silence the expected JVM-side error output while provoking the failure.
    with QuietTest(self.sc), self.assertRaises(PySparkTypeError) as ctx:
        self.spark.catalog.registerFunction("foo_udf", grouped_map_udf)

    # The new error-class framework carries no message parameters here.
    self.check_error(
        exception=ctx.exception,
        error_class="INVALID_UDF_EVAL_TYPE",
        message_parameters={},
    )

def test_decorator(self):
df = self.data

Expand Down
22 changes: 19 additions & 3 deletions python/pyspark/sql/tests/pandas/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit
from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType
from pyspark.errors import ParseException, PythonException
from pyspark.errors import ParseException, PythonException, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
Expand Down Expand Up @@ -153,18 +153,34 @@ def foo(x):
def zero_with_type():
return 1

with self.assertRaisesRegex(TypeError, "Invalid return type"):
with self.assertRaises(PySparkTypeError) as pe:

@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df

with self.assertRaisesRegex(TypeError, "Invalid return type"):
self.check_error(
exception=pe.exception,
error_class="NOT_DATATYPE_OR_STR",
message_parameters={"arg_name": "returnType", "arg_type": "int"},
)

with self.assertRaises(PySparkTypeError) as pe:

@pandas_udf(returnType="double", functionType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df

self.check_error(
exception=pe.exception,
error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
message_parameters={
"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF "
"or SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
"return_type": "DoubleType()",
},
)

with self.assertRaisesRegex(ValueError, "Invalid function"):

@pandas_udf(returnType="k int, v double", functionType=PandasUDFType.GROUPED_MAP)
Expand Down
11 changes: 9 additions & 2 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
TimestampNTZType,
DayTimeIntervalType,
)
from pyspark.errors import AnalysisException
from pyspark.errors import AnalysisException, PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
from pyspark.testing.utils import QuietTest

Expand Down Expand Up @@ -109,11 +109,18 @@ def test_udf_registration_return_type_not_none(self):
self.check_udf_registration_return_type_not_none()

def check_udf_registration_return_type_not_none(self):
    """Negative check: passing an explicit returnType together with an
    already-built UserDefinedFunction must raise PySparkTypeError with the
    CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF error class."""
    udf_with_type = UserDefinedFunction(lambda x, y: len(x) + y, StringType())
    with self.assertRaises(PySparkTypeError) as ctx:
        # returnType is redundant (and forbidden) when f is already a UDF.
        self.spark.catalog.registerFunction("f", udf_with_type, StringType())

    self.check_error(
        exception=ctx.exception,
        error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
        message_parameters={"arg_name": "f", "return_type": "StringType()"},
    )

def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
import random
Expand Down
62 changes: 38 additions & 24 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pyspark.sql.utils import get_active_spark_context
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.errors import PySparkTypeError

if TYPE_CHECKING:
from pyspark.sql._typing import DataTypeOrString, ColumnOrName, UserDefinedFunctionLike
Expand Down Expand Up @@ -218,20 +219,24 @@ def __init__(
deterministic: bool = True,
):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
"{0}".format(type(func))
raise PySparkTypeError(
error_class="NOT_CALLABLE",
message_parameters={"arg_name": "func", "arg_type": type(func).__name__},
)

if not isinstance(returnType, (DataType, str)):
raise TypeError(
"Invalid return type: returnType should be DataType or str "
"but is {}".format(returnType)
raise PySparkTypeError(
error_class="NOT_DATATYPE_OR_STR",
message_parameters={
"arg_name": "returnType",
"arg_type": type(returnType).__name__,
},
)

if not isinstance(evalType, int):
raise TypeError(
"Invalid evaluation type: evalType should be an int but is {}".format(evalType)
raise PySparkTypeError(
error_class="NOT_INT",
message_parameters={"arg_name": "evalType", "arg_type": type(evalType).__name__},
)

self.func = func
Expand Down Expand Up @@ -280,10 +285,13 @@ def returnType(self) -> DataType:
% str(self._returnType_placeholder)
)
else:
raise TypeError(
"Invalid return type for grouped map Pandas "
"UDFs or at groupby.applyInPandas(WithState): return type must be a "
"StructType."
raise PySparkTypeError(
error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
message_parameters={
"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
"SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
"return_type": str(self._returnType_placeholder),
},
)
elif (
self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
Expand All @@ -298,9 +306,12 @@ def returnType(self) -> DataType:
"%s is not supported" % str(self._returnType_placeholder)
)
else:
raise TypeError(
"Invalid return type in mapInPandas/mapInArrow: "
"return type must be a StructType."
raise PySparkTypeError(
error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
message_parameters={
"eval_type": "SQL_MAP_PANDAS_ITER_UDF or SQL_MAP_ARROW_ITER_UDF",
"return_type": str(self._returnType_placeholder),
},
)
elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
Expand All @@ -312,9 +323,12 @@ def returnType(self) -> DataType:
"%s is not supported" % str(self._returnType_placeholder)
)
else:
raise TypeError(
"Invalid return type in cogroup.applyInPandas: "
"return type must be a StructType."
raise PySparkTypeError(
error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
message_parameters={
"eval_type": "SQL_COGROUPED_MAP_PANDAS_UDF",
"return_type": str(self._returnType_placeholder),
},
)
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
Expand Down Expand Up @@ -591,9 +605,9 @@ def register(
# Python function.
if hasattr(f, "asNondeterministic"):
if returnType is not None:
raise TypeError(
"Invalid return type: data type can not be specified when f is"
"a user-defined function, but got %s." % returnType
raise PySparkTypeError(
error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
message_parameters={"arg_name": "f", "return_type": str(returnType)},
)
f = cast("UserDefinedFunctionLike", f)
if f.evalType not in [
Expand All @@ -602,9 +616,9 @@ def register(
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
]:
raise ValueError(
"Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
raise PySparkTypeError(
error_class="INVALID_UDF_EVAL_TYPE",
message_parameters={},
)
register_udf = _create_udf(
f.func,
Expand Down