[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors #24834
Conversation
I think there might be a better way to propagate exceptions from the Python connection serving thread for these cases. Here I duplicated the …
From the discussion in #24677, regarding the …
```scala
 * This is the same as serveToStream, only it returns a server object that
 * can be used to sync in Python.
 */
private[spark] def serveToStreamWithSync(
```
This could be cleaned up and replace the existing `serveToStream`. It just returns the `SocketAuthServer` object as the third element in the Array, and it could be ignored if no synchronization is needed.
```scala
private[spark] class SocketFuncServer(
    authHelper: SocketAuthHelper,
    threadName: String,
    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
```
I don't think we need `SocketAuthServer.setupOneConnectionServer` if we have this too, so it could be cleaned up.
I removed `SocketAuthServer.setupOneConnectionServer` and replaced its usage with `SocketFuncServer`.
python/pyspark/sql/dataframe.py
Outdated
```python
# Collect list of un-ordered batches where last element is a list of correct order indices
results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
from pyspark.rdd import _create_local_socket
```
This should be cleaned up; the code below basically duplicates `_load_from_socket`.
I'm attaching error messages from …
To sum up the differences: (3) this PR is more similar to (1) in that it has a Py4JJavaError and the …
What do you guys think @HyukjinKwon @felixcheung @dvogelbacher?
Test build #106364 has finished for PR 24834 at commit
I really like the idea. This is much better than having to define a specific protocol for propagating errors like we currently do. From having a short look at the R code, it seems like R would also be affected by https://jira.apache.org/jira/browse/SPARK-27805, and using this same mechanism in R would fix it there, too?
python/pyspark/sql/dataframe.py
Outdated
```python
from pyspark.rdd import _create_local_socket
sock_file = _create_local_socket((port, auth_secret))
results = list(ArrowCollectSerializer().load_stream(sock_file))
jserver_obj.getResult()  # Join serving thread and raise any exceptions
```
We might want to have this in a `finally` clause, so that if we have an error during serialization (which might be caused by an exception in the JVM) we will still get the original exception.
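The suggested try/finally pattern can be sketched in plain Python. `FakeServer` and `load_with_join` are hypothetical stand-ins for the Py4J server handle and the loading code, not Spark's actual API:

```python
# A minimal sketch of the try/finally join pattern. FakeServer and
# load_with_join are illustrative stand-ins, not Spark's actual API.

class FakeServer:
    """Stands in for the Py4J server handle whose getResult() joins the
    serving thread and re-raises any JVM-side exception."""
    def __init__(self, error=None):
        self.error = error
        self.joined = False

    def getResult(self):
        self.joined = True
        if self.error is not None:
            raise self.error


def load_with_join(load_stream, server):
    """Join the serving thread even if deserialization fails, so the
    original server-side error is not masked by a secondary one."""
    try:
        return list(load_stream())
    finally:
        server.getResult()
```

Because the join happens in `finally`, a JVM-side error surfaces even when the Python-side deserialization fails first.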
Looks reasonable; agreed with @dvogelbacher's comment.
Thanks @dvogelbacher and @felixcheung, I will clean this up then and apply the same fix to …
python/pyspark/sql/dataframe.py
Outdated
```python
results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
from pyspark.rdd import _create_local_socket
sock_file = _create_local_socket((port, auth_secret))
results = list(ArrowCollectSerializer().load_stream(sock_file))
```
Yeah, looks the same as `_load_from_socket`.
python/pyspark/sql/dataframe.py
Outdated
```python
@@ -2200,10 +2200,13 @@ def _collectAsArrow(self):
        .. note:: Experimental.
        """
        with SCCallSiteSync(self._sc) as css:
            sock_info = self._jdf.collectAsArrowToPython()
            port, auth_secret, jserver_obj = self._jdf.collectAsArrowToPython()
```
Do you think it makes sense to make a `_serialize_from_jvm` (like `_serialize_to_jvm`)? This could be done separately.
Possibly; `_load_from_socket` is basically deserializing from the JVM.
@BryanCutler, I don't mind, but I don't feel strongly about backporting. If you think we should, we can.
The approach looks fine.
Force-pushed from f20a156 to ead8978.
```scala
    JavaUtils.closeQuietly(serverSocket)
    JavaUtils.closeQuietly(sock)
  }

  def serveToStream(
```
Moved this from `SocketAuthHelper` because it seemed more fitting here.
```diff
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
         originalThrowable = cause
         try {
           logError("Aborting task", originalThrowable)
-          TaskContext.get().markTaskFailed(originalThrowable)
+          if (TaskContext.get() != null) {
```
Using this utility here https://github.com/apache/spark/pull/24834/files#diff-0a67bc4d171abe4df8eb305b0f4123a2R184, where the task fails and completes before hitting the `catchBlock`, so `TaskContext.get()` returns null.
I cleaned up some things with …
Test build #106739 has finished for PR 24834 at commit
Test build #106740 has finished for PR 24834 at commit
Let me give some input within 3 days.
```scala
@@ -66,42 +87,45 @@ private[spark] abstract class SocketAuthServer[T](
  }

  /**
   * Create a socket server class and run user function on the socket in a background thread.
   * This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
```
Seems we don't have setupOneConnectionServer anymore.
Oops, good catch
""" | ||
port = sock_info[0] | ||
auth_secret = sock_info[1] | ||
sockfile, sock = local_connect_and_auth(port, auth_secret) | ||
# The RDD materialization time is unpredictable, if we set a timeout for socket reading | ||
# operation, it will very possibly fail. See SPARK-18281. | ||
sock.settimeout(None) | ||
return sockfile | ||
|
||
|
||
def _load_from_socket(sock_info, serializer): |
@BryanCutler, what does `sock_info` expect to be? It seems it can be both a 2-tuple and a 3-tuple (with server).
Ugh, yeah, I'm not too happy with this. Java returns a 3-tuple of (port, auth_secret, server), and most places only use the first 2, such as `_load_from_socket`. It gets a little confusing, so I thought it might be better to expand the values returned by Java for `serveToStream` etc., but it ended up with a lot of changes where the third value is ignored, like `port, auth_secret, _ = ...`, and I don't think it really made things clearer. I'll try to think of something better and maybe do a followup.
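One way out of the 2-tuple/3-tuple ambiguity would be a small normalizing helper on the Python side. `parse_sock_info` below is a hypothetical sketch, not something in this PR:

```python
def parse_sock_info(sock_info):
    """Normalize the (port, auth_secret[, server]) value returned from the
    JVM so callers always unpack three elements. Hypothetical helper."""
    if len(sock_info) == 3:
        port, auth_secret, server = sock_info
    else:
        port, auth_secret = sock_info
        server = None  # no server handle; caller cannot join the serving thread
    return port, auth_secret, server
```

Callers that don't need to join simply ignore the third element, and the tuple shape is handled in one place.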
R side, I will take a look after merging this in.
Looks good, except that I think it's now difficult to read `sock_info` .. https://github.com/apache/spark/pull/24834/files#r296549441
BTW, let's avoid refactoring the code together with a fix; it took me a while to understand the fix and track the changes.
Generally makes sense to me -- just a few comments on things which would help me follow this code when I occasionally look at it, more as a java developer that occasionally needs to work on the communication protocol.
One important thing -- please change the summary / description to replace "synchronize" with "join".
```scala
 * Create a socket server and run user function on the socket in a background thread.
 *
 * The socket server can only accept one connection, or close if no connection
 * in 15 seconds.
```
Please save this comment -- I guess move it to the class `SocketAuthServer`. In particular, it's helpful to note that this only accepts one connection; it's not a long-lived thing which is reused.
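The one-connection, bounded-wait behaviour that comment documents can be illustrated with a toy Python server. `one_shot_server` is illustrative only, not Spark's Scala implementation:

```python
import socket
import threading

def one_shot_server(handle, timeout=15.0):
    """Toy analog of a one-connection server: accepts a single connection
    or gives up after `timeout` seconds (illustrative, not Spark's code)."""
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("127.0.0.1", 0))  # OS picks a free port
    srv.listen(1)
    srv.settimeout(timeout)
    port = srv.getsockname()[1]

    def serve():
        try:
            conn, _ = srv.accept()  # only one connection is ever accepted
            with conn:
                handle(conn)
        finally:
            srv.close()  # the server is not reused afterwards

    t = threading.Thread(target=serve, daemon=True)
    t.start()
    return port, t
```

After the single connection is handled (or the accept times out), the listening socket is closed, matching the "not long-lived, not reused" point above.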
Yeah that is a useful comment, I didn't intend to take this out. I'll put it back in.
python/pyspark/rdd.py
Outdated
```python
@@ -159,7 +174,8 @@ class PyLocalIterable(object):
    """ Create a synchronous local iterable over a socket """

    def __init__(self, _sock_info, _serializer):
        self._sockfile = _create_local_socket(_sock_info)
        port, auth_secret, self.jserver_obj = _sock_info
```
Just a general comment, which might not make the most sense to address in this particular PR -- I'd find it really helpful if the Python code dealing with Java objects would annotate (somehow) the Java types. It's hard for me to figure out whether `jserver_obj` is a ServerSocket, a SocketAuthServer, a Py4JJavaServer, etc.
Sure, I can rename it to something more fitting, and I agree it should be clear from the name what the variable is.
Test build #106844 has finished for PR 24834 at commit
Thanks @HyukjinKwon and @squito for reviewing, I addressed your comments.
Merged to master, thanks all for reviewing!
…nection thread to propagate errors

### What changes were proposed in this pull request?

This PR proposes to backport #24834 with minimised changes, plus the tests added at #25594. #24834 was not backported before because it mainly targeted a better exception by propagating the exception from the JVM. However, that PR actually fixed another problem accidentally (see #25594 and [SPARK-28881](https://issues.apache.org/jira/browse/SPARK-28881)). This regression seems to have been introduced by #21546. The root cause is that in https://github.com/apache/spark/blob/23bed0d3c08e03085d3f0c3a7d457eedd30bd67f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L3370-L3384, `runJob` with `resultHandler` seems able to write partial output. The JVM throws an exception, but since the JVM exception is not propagated into the Python process, the Python process doesn't know whether an exception was thrown from the JVM (it just closes the socket), which results in the following:

```
./bin/pyspark --conf spark.driver.maxResultSize=1m
```

```python
spark.conf.set("spark.sql.execution.arrow.enabled", True)
spark.range(10000000).toPandas()
```

```
Empty DataFrame
Columns: [id]
Index: []
```

With this change, the Python process catches exceptions from the JVM.

### Why are the changes needed?

It returns incorrect data, and potentially partial results, when an exception happens on the JVM side. This is a regression; the code works fine in Spark 2.3.3.

### Does this PR introduce any user-facing change?

Yes. The same repro as above now throws an exception as expected:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../pyspark/sql/dataframe.py", line 2122, in toPandas
    batches = self._collectAsArrow()
  File "/.../pyspark/sql/dataframe.py", line 2184, in _collectAsArrow
    jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions
  File "/.../lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/.../pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o42.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult:
...
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Total size of serialized results of 1 tasks (6.5 MB) is bigger than spark.driver.maxResultSize (1024.0 KB)
```

### How was this patch tested?

Manually as described above; a unit test was added.

Closes #25593 from HyukjinKwon/SPARK-27992.

Lead-authored-by: HyukjinKwon <gurwls223@apache.org>
Co-authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
## What changes were proposed in this pull request?

Currently with `toLocalIterator()` and `toPandas()` with Arrow enabled, if the Spark job being run in the background serving thread errors, the error is caught and sent to Python through the PySpark serializer. This is not ideal because it only catches a SparkException, it won't handle an error that occurs in the serializer, and each method has to have its own special handling to propagate the error.

This PR instead returns the Python server object along with the serving port and authentication info, which allows the Python caller to join with the serving thread. During the call to join, the serving thread Future is completed either successfully or with an exception; in the latter case, the exception is propagated to Python through the Py4J call.

## How was this patch tested?

Existing tests
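The join mechanism described above can be sketched with a Future: the background serving thread completes it with either a result or an exception, and the caller's join re-raises the latter. `MiniAuthServer` below is a toy analog of this pattern, not Spark's actual `SocketAuthServer`:

```python
import concurrent.futures
import threading

class MiniAuthServer:
    """Toy analog of the pattern in this PR: run `func` in a background
    thread and record its outcome in a Future (illustrative only)."""
    def __init__(self, func):
        self._future = concurrent.futures.Future()
        threading.Thread(target=self._run, args=(func,), daemon=True).start()

    def _run(self, func):
        try:
            self._future.set_result(func())
        except Exception as e:
            # The Future carries the exception to whoever joins later.
            self._future.set_exception(e)

    def get_result(self):
        # Blocks like the Py4J getResult() call; re-raises any exception
        # raised by the background function.
        return self._future.result(timeout=5)
```

A caller that invokes `get_result()` after consuming the stream sees either the thread's result or its original exception, with no per-method error protocol.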