
Commit: Fix tests
HyukjinKwon committed May 29, 2020
1 parent fb22d68 commit 28c2a51
Showing 2 changed files with 11 additions and 9 deletions.
14 changes: 6 additions & 8 deletions python/pyspark/sql/tests/test_pandas_udf.py
@@ -19,14 +19,12 @@
 
 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import *
-from pyspark.sql.utils import ParseException
+from pyspark.sql.utils import ParseException, PythonException
 from pyspark.rdd import PythonEvalType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
-from py4j.protocol import Py4JJavaError
-
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -157,14 +155,14 @@ def foofoo(x, y):
 
         # plain udf (test for SPARK-23754)
         self.assertRaisesRegexp(
-            Py4JJavaError,
+            PythonException,
             exc_message,
             df.withColumn('v', udf(foo)('id')).collect
         )
 
         # pandas scalar udf
         self.assertRaisesRegexp(
-            Py4JJavaError,
+            PythonException,
             exc_message,
             df.withColumn(
                 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
@@ -173,15 +171,15 @@ def foofoo(x, y):
 
         # pandas grouped map
         self.assertRaisesRegexp(
-            Py4JJavaError,
+            PythonException,
             exc_message,
             df.groupBy('id').apply(
                 pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
             ).collect
         )
 
         self.assertRaisesRegexp(
-            Py4JJavaError,
+            PythonException,
             exc_message,
             df.groupBy('id').apply(
                 pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
@@ -190,7 +188,7 @@ def foofoo(x, y):
 
         # pandas grouped agg
         self.assertRaisesRegexp(
-            Py4JJavaError,
+            PythonException,
             exc_message,
             df.groupBy('id').agg(
                 pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
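Illustration (not part of this commit): a minimal sketch of the driver-side behavior the updated tests rely on. After this change, an error raised inside a Python or pandas UDF surfaces as pyspark.sql.utils.PythonException rather than a raw py4j Py4JJavaError. The session setup, the `failing` UDF, and the ValueError below are illustrative assumptions, not code from the repository; the sketch assumes PySpark 3.0 with pandas and pyarrow installed.

# Hypothetical demo script, not from the Spark repository.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import PythonException

spark = SparkSession.builder.master("local[1]").appName("python-exception-demo").getOrCreate()
df = spark.range(3)

# A pandas scalar UDF that always fails inside the Python worker.
@pandas_udf('double', PandasUDFType.SCALAR)
def failing(v):
    raise ValueError("boom")

try:
    df.withColumn('v', failing(df['id'])).collect()
except PythonException:
    # With the guard added in pyspark.sql.utils.convert_exception, the JVM-side
    # error is rewrapped so the driver sees the Python worker's traceback.
    print("caught PythonException raised by the UDF")
finally:
    spark.stop()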
6 changes: 5 additions & 1 deletion python/pyspark/sql/utils.py
@@ -114,7 +114,11 @@ def convert_exception(e):
         return QueryExecutionException(s.split(': ', 1)[1], stacktrace, c)
     if s.startswith('java.lang.IllegalArgumentException: '):
         return IllegalArgumentException(s.split(': ', 1)[1], stacktrace, c)
-    if c is not None and c.toString().startswith('org.apache.spark.api.python.PythonException: '):
+    if c is not None and (
+            c.toString().startswith('org.apache.spark.api.python.PythonException: ')
+            # To make sure this only catches Python UDFs.
+            and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(),
+                        c.getStackTrace()))):
         msg = ("\n An exception was thrown from Python worker in the executor. "
                "The below is the Python worker stacktrace.\n%s" % c.getMessage())
         return PythonException(msg, stacktrace)
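Illustration (not part of this commit): a pure-Python sketch of the guard added above. The conversion to PythonException now also requires that at least one frame of the cause's stack trace comes from org.apache.spark.sql.execution.python, so only failures of SQL Python/pandas UDFs are rewrapped. FakeFrame, FakeCause, and looks_like_python_udf_failure below are hypothetical stand-ins for the py4j proxies and the condition that convert_exception actually evaluates.

# Hypothetical stand-ins for the py4j Throwable/StackTraceElement proxies.
class FakeFrame:
    def __init__(self, cls_name):
        self._cls_name = cls_name

    def toString(self):
        return self._cls_name


class FakeCause:
    def __init__(self, description, frames):
        self._description = description
        self._frames = frames

    def toString(self):
        return self._description

    def getStackTrace(self):
        return self._frames


def looks_like_python_udf_failure(c):
    # Same condition as the new branch in convert_exception.
    return c is not None and (
        c.toString().startswith('org.apache.spark.api.python.PythonException: ')
        and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(),
                    c.getStackTrace())))


udf_cause = FakeCause(
    'org.apache.spark.api.python.PythonException: Traceback ...',
    [FakeFrame('org.apache.spark.sql.execution.python.ArrowEvalPythonExec')])
rdd_cause = FakeCause(
    'org.apache.spark.api.python.PythonException: Traceback ...',
    [FakeFrame('org.apache.spark.api.python.PythonRunner')])

print(looks_like_python_udf_failure(udf_cause))  # True: rewrapped as PythonException
print(looks_like_python_udf_failure(rdd_cause))  # False: left as the original Py4JJavaError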
