[SPARK-44839][SS][CONNECT] Better Error Logging when user tries to serialize spark session

### What changes were proposed in this pull request?

Add a new error with a detailed message for when a user tries to access a Spark session, or a DataFrame created using a local Spark session, inside streaming Spark Connect `foreach`, `foreachBatch`, and `StreamingQueryListener` functions.

Update: per a reviewer's request, added a new error class `PySparkPicklingError` and moved `UDTF_SERIALIZATION_ERROR` to the new class.

### Why are the changes needed?

Better error logging for the breaking change introduced in streaming Spark Connect.

### Does this PR introduce _any_ user-facing change?

Yes. Previously, users could only see the following uninformative error when they accessed a local Spark session in their streaming Connect-related functions:
```
Traceback (most recent call last):
  File "/home/wei.liu/oss-spark/python/pyspark/serializers.py", line 459, in dumps
    return cloudpickle.dumps(obj, pickle_protocol)
  File "/home/wei.liu/oss-spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/wei.liu/oss-spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle '_thread._local' object

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/wei.liu/oss-spark/python/pyspark/sql/connect/streaming/readwriter.py", line 508, in foreachBatch
    self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
  File "/home/wei.liu/oss-spark/python/pyspark/serializers.py", line 469, in dumps
    raise pickle.PicklingError(msg)
_pickle.PicklingError: Could not serialize object: TypeError: cannot pickle '_thread._local' object
```

Now it is replaced with:
```
pyspark.errors.exceptions.base.PySparkPicklingError: [STREAMING_CONNECT_SERIALIZATION_ERROR] Cannot serialize the function `foreachBatch`. If you accessed the spark session, or a dataframe defined outside of the function, please be aware that they are not allowed in Spark Connect. For foreachBatch, please access the spark session using `df.sparkSession`, where `df` is the first parameter in your foreachBatch function. For StreamingQueryListener, please access the spark session using `self.spark`
```
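For reference, a minimal sketch of the pattern the new message recommends for `foreachBatch`, assuming a Spark Connect session named `spark` (the batch function body is illustrative):
```
# Recommended pattern: do not capture the outer `spark` session in the
# closure; take the session from the micro-batch DataFrame instead.
def process_batch(df, batch_id):
    session = df.sparkSession  # session attached to the batch DataFrame
    session.createDataFrame([(batch_id,)], ["batch_id"]).collect()

query = (
    spark.readStream.format("rate").load()
    .writeStream.foreachBatch(process_batch)
    .start()
)
```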

### How was this patch tested?

Add unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#42594 from WweiL/SPARK-44839-spark-session-error.

Lead-authored-by: Wei Liu <wei.liu@databricks.com>
Co-authored-by: Wei Liu <z920631580@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2 people authored and HyukjinKwon committed Aug 25, 2023
1 parent bf3bef1 commit 0b3a582
Showing 10 changed files with 130 additions and 16 deletions.
2 changes: 2 additions & 0 deletions python/pyspark/errors/__init__.py
@@ -41,6 +41,7 @@
PySparkRuntimeError,
PySparkAssertionError,
PySparkNotImplementedError,
PySparkPicklingError,
)


@@ -67,4 +68,5 @@
"PySparkRuntimeError",
"PySparkAssertionError",
"PySparkNotImplementedError",
"PySparkPicklingError",
]
5 changes: 5 additions & 0 deletions python/pyspark/errors/error_classes.py
@@ -718,6 +718,11 @@
"pandas iterator UDF should exhaust the input iterator."
]
},
"STREAMING_CONNECT_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the function `<name>`. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`."
]
},
"TOO_MANY_VALUES" : {
"message" : [
"Expected <expected> values for `<item>`, got <actual>."
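The new entry is a message template: the `<name>` placeholder is filled from `message_parameters` when the exception is constructed. A minimal sketch of how the templated error surfaces, using only the `pyspark.errors` API shown in this diff:
```
from pyspark.errors import PySparkPicklingError

# <name> in the error-class message template is substituted from
# message_parameters when the exception is built.
err = PySparkPicklingError(
    error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
    message_parameters={"name": "foreachBatch"},
)
print(err.getErrorClass())  # STREAMING_CONNECT_SERIALIZATION_ERROR
print(err)  # [STREAMING_CONNECT_SERIALIZATION_ERROR] Cannot serialize the function `foreachBatch`. ...
```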
7 changes: 7 additions & 0 deletions python/pyspark/errors/exceptions/base.py
@@ -18,6 +18,7 @@
from typing import Dict, Optional, cast

from pyspark.errors.utils import ErrorClassesReader
from pickle import PicklingError


class PySparkException(Exception):
@@ -226,3 +227,9 @@ class PySparkNotImplementedError(PySparkException, NotImplementedError):
"""
Wrapper class for NotImplementedError to support error classes.
"""


class PySparkPicklingError(PySparkException, PicklingError):
"""
Wrapper class for pickle.PicklingError to support error classes.
"""
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -44,7 +44,7 @@
from pyspark.errors import (
PySparkTypeError,
PySparkNotImplementedError,
PySparkRuntimeError,
PySparkPicklingError,
IllegalArgumentException,
)

@@ -2211,7 +2211,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
try:
udtf.command = CloudPickleSerializer().dumps(self._func)
except pickle.PicklingError:
raise PySparkRuntimeError(
raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
10 changes: 9 additions & 1 deletion python/pyspark/sql/connect/streaming/query.py
@@ -17,6 +17,7 @@

import json
import sys
import pickle
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional

from pyspark.errors import StreamingQueryException, PySparkValueError
@@ -32,6 +33,7 @@
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)
from pyspark.errors import PySparkPicklingError

__all__ = ["StreamingQuery", "StreamingQueryManager"]

@@ -237,7 +239,13 @@ def addListener(self, listener: StreamingQueryListener) -> None:
listener._init_listener_id()
cmd = pb2.StreamingQueryManagerCommand()
expr = proto.PythonUDF()
expr.command = CloudPickleSerializer().dumps(listener)
try:
expr.command = CloudPickleSerializer().dumps(listener)
except pickle.PicklingError:
raise PySparkPicklingError(
error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
message_parameters={"name": "addListener"},
)
expr.python_ver = get_python_ver()
cmd.add_listener.python_listener_payload.CopyFrom(expr)
cmd.add_listener.id = listener._id
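For listeners, the guidance in the new message is to use `self.spark` inside callbacks rather than capturing a session from the enclosing scope. A sketch, assuming a Spark Connect session named `spark` and an illustrative callback body:
```
from pyspark.sql.streaming.listener import StreamingQueryListener

class MyListener(StreamingQueryListener):
    def onQueryStarted(self, event):
        # Using `self.spark` avoids pickling a captured client-side session,
        # which would now raise STREAMING_CONNECT_SERIALIZATION_ERROR.
        self.spark.createDataFrame([(str(event.id),)], ["query_id"]).collect()

    def onQueryProgress(self, event):
        pass

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        pass

spark.streams.addListener(MyListener())
```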
27 changes: 20 additions & 7 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -20,6 +20,7 @@
check_dependencies(__name__)

import sys
import pickle
from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union

from pyspark.serializers import CloudPickleSerializer
@@ -33,7 +34,7 @@
)
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.types import Row, StructType
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError

if TYPE_CHECKING:
from pyspark.sql.connect.session import SparkSession
@@ -488,18 +489,30 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt
serializer = AutoBatchedSerializer(CPickleSerializer())
command = (func, None, serializer, serializer)
# Python ForeachWriter isn't really a PythonUDF. But we reuse it for simplicity.
self._write_proto.foreach_writer.python_function.command = CloudPickleSerializer().dumps(
command
)
try:
self._write_proto.foreach_writer.python_function.command = (
CloudPickleSerializer().dumps(command)
)
except pickle.PicklingError:
raise PySparkPicklingError(
error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
message_parameters={"name": "foreach"},
)
self._write_proto.foreach_writer.python_function.python_ver = "%d.%d" % sys.version_info[:2]
return self

foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__

def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
func
)
try:
self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
func
)
except pickle.PicklingError:
raise PySparkPicklingError(
error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
message_parameters={"name": "foreachBatch"},
)
self._write_proto.foreach_batch.python_function.python_ver = get_python_ver()
return self

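The same rule applies to `foreach`: the row handler must not reference the client session or an outer DataFrame, or pickling fails with the new error. A minimal sketch, assuming a Spark Connect session named `spark`:
```
# The handler only touches the row it receives; referencing `spark` or an
# outer DataFrame in this closure would now raise
# STREAMING_CONNECT_SERIALIZATION_ERROR when the function is pickled.
def process_row(row):
    print(row)

query = (
    spark.readStream.format("rate").load()
    .writeStream.foreach(process_row)
    .start()
)
```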
30 changes: 30 additions & 0 deletions python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -19,6 +19,7 @@

from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.errors import PySparkPicklingError


class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
@@ -30,6 +31,35 @@ def test_streaming_foreachBatch_propagates_python_errors(self):
def test_streaming_foreachBatch_graceful_stop(self):
super().test_streaming_foreachBatch_graceful_stop()

# class StreamingForeachBatchParityTests(ReusedConnectTestCase):
def test_accessing_spark_session(self):
spark = self.spark

def func(df, _):
spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

error_thrown = False
try:
self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)

def test_accessing_spark_session_through_df(self):
dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

def func(df, _):
dataframe.collect()

error_thrown = False
try:
self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)


if __name__ == "__main__":
import unittest
49 changes: 49 additions & 0 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -19,6 +19,7 @@
import time

import pyspark.cloudpickle
from pyspark.errors import PySparkPicklingError
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
from pyspark.sql.streaming.listener import StreamingQueryListener
from pyspark.sql.functions import count, lit
@@ -94,6 +95,54 @@ def test_listener_events(self):
# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)

def test_accessing_spark_session(self):
spark = self.spark

class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

def onQueryProgress(self, event):
pass

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
pass

error_thrown = False
try:
self.spark.streams.addListener(TestListener())
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)

def test_accessing_spark_session_through_df(self):
dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
dataframe.collect()

def onQueryProgress(self, event):
pass

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
pass

error_thrown = False
try:
self.spark.streams.addListener(TestListener())
except PySparkPicklingError as e:
self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
error_thrown = True
self.assertTrue(error_thrown)


if __name__ == "__main__":
import unittest
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/test_udtf.py
@@ -27,7 +27,7 @@
PythonException,
PySparkTypeError,
AnalysisException,
PySparkRuntimeError,
PySparkPicklingError,
)
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
@@ -872,7 +872,7 @@ def eval(self):
file_obj
yield 1,

with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
TestUDTF().collect()

def test_udtf_access_spark_session(self):
Expand All @@ -884,7 +884,7 @@ def eval(self):
df.collect()
yield 1,

with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"):
with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"):
TestUDTF().collect()

def test_udtf_no_eval(self):
Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/sql/udtf.py
@@ -26,7 +26,7 @@

from py4j.java_gateway import JavaObject

from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -234,7 +234,7 @@ def _create_judtf(self, func: Type) -> JavaObject:
wrapped_func = _wrap_function(sc, func)
except pickle.PicklingError as e:
if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
raise PySparkRuntimeError(
raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
@@ -244,7 +244,7 @@
"and try again.",
},
) from None
raise PySparkRuntimeError(
raise PySparkPicklingError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
