diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py index a9bcb973a6fcf..53d5a266557f5 100644 --- a/python/pyspark/errors/__init__.py +++ b/python/pyspark/errors/__init__.py @@ -41,6 +41,7 @@ PySparkRuntimeError, PySparkAssertionError, PySparkNotImplementedError, + PySparkPicklingError, ) @@ -67,4 +68,5 @@ "PySparkRuntimeError", "PySparkAssertionError", "PySparkNotImplementedError", + "PySparkPicklingError", ] diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 9b5d8954c20ae..ca448a169e83b 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -718,6 +718,11 @@ "pandas iterator UDF should exhaust the input iterator." ] }, + "STREAMING_CONNECT_SERIALIZATION_ERROR" : { + "message" : [ + "Cannot serialize the function ``. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`." + ] + }, "TOO_MANY_VALUES" : { "message" : [ "Expected values for ``, got ." diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py index fd1c07c4df68d..1d09a68dffbfe 100644 --- a/python/pyspark/errors/exceptions/base.py +++ b/python/pyspark/errors/exceptions/base.py @@ -18,6 +18,7 @@ from typing import Dict, Optional, cast from pyspark.errors.utils import ErrorClassesReader +from pickle import PicklingError class PySparkException(Exception): @@ -226,3 +227,9 @@ class PySparkNotImplementedError(PySparkException, NotImplementedError): """ Wrapper class for NotImplementedError to support error classes. """ + + +class PySparkPicklingError(PySparkException, PicklingError): + """ + Wrapper class for pickle.PicklingError to support error classes. + """ diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 4ef789e28a527..7952d2af99958 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -44,7 +44,7 @@ from pyspark.errors import ( PySparkTypeError, PySparkNotImplementedError, - PySparkRuntimeError, + PySparkPicklingError, IllegalArgumentException, ) @@ -2211,7 +2211,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF: try: udtf.command = CloudPickleSerializer().dumps(self._func) except pickle.PicklingError: - raise PySparkRuntimeError( + raise PySparkPicklingError( error_class="UDTF_SERIALIZATION_ERROR", message_parameters={ "name": self._name, diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 021d27e939de8..7cebc0e71ef1b 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -17,6 +17,7 @@ import json import sys +import pickle from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional from pyspark.errors import StreamingQueryException, PySparkValueError @@ -32,6 +33,7 @@ from pyspark.errors.exceptions.connect import ( StreamingQueryException as CapturedStreamingQueryException, ) +from pyspark.errors import PySparkPicklingError __all__ = ["StreamingQuery", "StreamingQueryManager"] @@ -237,7 +239,13 @@ def addListener(self, listener: StreamingQueryListener) -> None: listener._init_listener_id() cmd = pb2.StreamingQueryManagerCommand() expr = proto.PythonUDF() - expr.command = CloudPickleSerializer().dumps(listener) + try: + expr.command = CloudPickleSerializer().dumps(listener) + except pickle.PicklingError: + raise PySparkPicklingError( + error_class="STREAMING_CONNECT_SERIALIZATION_ERROR", + message_parameters={"name": "addListener"}, + ) expr.python_ver = get_python_ver() cmd.add_listener.python_listener_payload.CopyFrom(expr) cmd.add_listener.id = listener._id diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 89097fcf43a01..63ec7848d1eb0 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -20,6 +20,7 @@ check_dependencies(__name__) import sys +import pickle from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union from pyspark.serializers import CloudPickleSerializer @@ -33,7 +34,7 @@ ) from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.types import Row, StructType -from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError if TYPE_CHECKING: from pyspark.sql.connect.session import SparkSession @@ -488,18 +489,30 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt serializer = AutoBatchedSerializer(CPickleSerializer()) command = (func, None, serializer, serializer) # Python ForeachWriter isn't really a PythonUDF. But we reuse it for simplicity. - self._write_proto.foreach_writer.python_function.command = CloudPickleSerializer().dumps( - command - ) + try: + self._write_proto.foreach_writer.python_function.command = ( + CloudPickleSerializer().dumps(command) + ) + except pickle.PicklingError: + raise PySparkPicklingError( + error_class="STREAMING_CONNECT_SERIALIZATION_ERROR", + message_parameters={"name": "foreach"}, + ) self._write_proto.foreach_writer.python_function.python_ver = "%d.%d" % sys.version_info[:2] return self foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": - self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps( - func - ) + try: + self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps( + func + ) + except pickle.PicklingError: + raise PySparkPicklingError( + error_class="STREAMING_CONNECT_SERIALIZATION_ERROR", + message_parameters={"name": "foreachBatch"}, + ) self._write_proto.foreach_batch.python_function.python_ver = get_python_ver() return self diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py index 01108c95391bf..0718c6a88b0da 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py @@ -19,6 +19,7 @@ from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.errors import PySparkPicklingError class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase): @@ -30,6 +31,35 @@ def test_streaming_foreachBatch_propagates_python_errors(self): def test_streaming_foreachBatch_graceful_stop(self): super().test_streaming_foreachBatch_graceful_stop() + # class StreamingForeachBatchParityTests(ReusedConnectTestCase): + def test_accessing_spark_session(self): + spark = self.spark + + def func(df, _): + spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect() + + error_thrown = False + try: + self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() + except PySparkPicklingError as e: + self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") + error_thrown = True + self.assertTrue(error_thrown) + + def test_accessing_spark_session_through_df(self): + dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")]) + + def func(df, _): + dataframe.collect() + + error_thrown = False + try: + self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() + except PySparkPicklingError as e: + self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") + error_thrown = True + self.assertTrue(error_thrown) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 5069a76cfdb73..ca02cf29ee7d1 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -19,6 +19,7 @@ import time import pyspark.cloudpickle +from pyspark.errors import PySparkPicklingError from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin from pyspark.sql.streaming.listener import StreamingQueryListener from pyspark.sql.functions import count, lit @@ -94,6 +95,54 @@ def test_listener_events(self): # Remove again to verify this won't throw any error self.spark.streams.removeListener(test_listener) + def test_accessing_spark_session(self): + spark = self.spark + + class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect() + + def onQueryProgress(self, event): + pass + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + pass + + error_thrown = False + try: + self.spark.streams.addListener(TestListener()) + except PySparkPicklingError as e: + self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") + error_thrown = True + self.assertTrue(error_thrown) + + def test_accessing_spark_session_through_df(self): + dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")]) + + class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + dataframe.collect() + + def onQueryProgress(self, event): + pass + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + pass + + error_thrown = False + try: + self.spark.streams.addListener(TestListener()) + except PySparkPicklingError as e: + self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR") + error_thrown = True + self.assertTrue(error_thrown) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 63743de5e0322..a7545c332e6a0 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -27,7 +27,7 @@ PythonException, PySparkTypeError, AnalysisException, - PySparkRuntimeError, + PySparkPicklingError, ) from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -872,7 +872,7 @@ def eval(self): file_obj yield 1, - with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"): TestUDTF().collect() def test_udtf_access_spark_session(self): @@ -884,7 +884,7 @@ def eval(self): df.collect() yield 1, - with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + with self.assertRaisesRegex(PySparkPicklingError, "UDTF_SERIALIZATION_ERROR"): TestUDTF().collect() def test_udtf_no_eval(self): diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 3658ce9a1d813..72bba3d9a2c48 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -26,7 +26,7 @@ from py4j.java_gateway import JavaObject -from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError +from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError from pyspark.rdd import PythonEvalType from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -234,7 +234,7 @@ def _create_judtf(self, func: Type) -> JavaObject: wrapped_func = _wrap_function(sc, func) except pickle.PicklingError as e: if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e): - raise PySparkRuntimeError( + raise PySparkPicklingError( error_class="UDTF_SERIALIZATION_ERROR", message_parameters={ "name": self._name, @@ -244,7 +244,7 @@ def _create_judtf(self, func: Type) -> JavaObject: "and try again.", }, ) from None - raise PySparkRuntimeError( + raise PySparkPicklingError( error_class="UDTF_SERIALIZATION_ERROR", message_parameters={ "name": self._name,