Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,7 @@ def close(self) -> None:
"""
Close the channel.
"""
ExecutePlanResponseReattachableIterator.shutdown()
self._channel.close()
self._closed = True

Expand Down
37 changes: 32 additions & 5 deletions python/pyspark/sql/connect/client/reattach.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import warnings
import uuid
from collections.abc import Generator
from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast
from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from multiprocessing import RLock
from multiprocessing.synchronize import RLock as RLockBase
from multiprocessing.pool import ThreadPool
import os

Expand Down Expand Up @@ -53,7 +55,30 @@ class ExecutePlanResponseReattachableIterator(Generator):
ReleaseExecute RPCs that instruct the server to release responses that it already processed.
"""

_release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8)
# Lock to manage the pool
_lock: ClassVar[RLockBase] = RLock()
_release_thread_pool: Optional[ThreadPool] = ThreadPool(os.cpu_count() if os.cpu_count() else 8)

@classmethod
def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None:
    """
    Tear down the shared release-call thread pool.

    Called right before the channel is closed so that every outstanding
    ReleaseExecute call has completed by the time the channel goes away.
    Safe to call more than once; subsequent calls are no-ops.
    """
    with cls._lock:
        pool = cls._release_thread_pool
        if pool is not None:
            pool.close()
            pool.join()
            cls._release_thread_pool = None
Comment on lines +62 to +72
Copy link
Contributor

@juliuszsompolski juliuszsompolski Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've seen (and ignored for now...) the scala equivalent of this failing when we do SparkConnectClient.shutdown, which does channel.shutdownNow(). In scala, we don't have a dedicated threadpool for that, but (ab)use a grpc thread in https://github.com/apache/spark/blob/master/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala#L179
I wonder if more graceful shutdown of the channel would fixed it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@classmethod
def _initialize_pool_if_necessary(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None:
    """
    Lazily (re)create the shared release-call thread pool.

    Holds the class lock so that, after a shutdown(), exactly one caller
    re-creates the pool and every other caller observes that instance.
    """
    with cls._lock:
        if cls._release_thread_pool is not None:
            return
        workers = os.cpu_count() or 8
        cls._release_thread_pool = ThreadPool(workers)

def __init__(
self,
Expand All @@ -62,6 +87,7 @@ def __init__(
retry_policy: Dict[str, Any],
metadata: Iterable[Tuple[str, str]],
):
ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
self._request = request
self._retry_policy = retry_policy
if request.operation_id:
Expand Down Expand Up @@ -111,7 +137,6 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse:

self._last_returned_response_id = ret.response_id
if ret.HasField("result_complete"):
self._result_complete = True
self._release_all()
else:
self._release_until(self._last_returned_response_id)
Expand Down Expand Up @@ -190,7 +215,8 @@ def target() -> None:
except Exception as e:
warnings.warn(f"ReleaseExecute failed with exception: {e}.")

ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
if ExecutePlanResponseReattachableIterator._release_thread_pool is not None:
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)

def _release_all(self) -> None:
"""
Expand Down Expand Up @@ -218,7 +244,8 @@ def target() -> None:
except Exception as e:
warnings.warn(f"ReleaseExecute failed with exception: {e}.")

ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
if ExecutePlanResponseReattachableIterator._release_thread_pool is not None:
ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target)
self._result_complete = True

def _call_iter(self, iter_fun: Callable) -> Any:
Expand Down
70 changes: 54 additions & 16 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ def setUp(self) -> None:
"jitter": 10,
"min_jitter_threshold": 10,
}
self.response = proto.ExecutePlanResponse()
self.response = proto.ExecutePlanResponse(
response_id="1",
)
self.finished = proto.ExecutePlanResponse(
result_complete=proto.ExecutePlanResponse.ResultComplete()
result_complete=proto.ExecutePlanResponse.ResultComplete(),
response_id="2",
)

def _stub_with(self, execute=None, attach=None):
Expand All @@ -147,15 +150,33 @@ def _stub_with(self, execute=None, attach=None):
attach_ops=ResponseGenerator(attach) if attach is not None else None,
)

def assertEventually(self, callable, timeout_ms=1000):
    """Repeatedly evaluate ``callable`` until it stops raising or the deadline passes.

    Success is defined as ``callable()`` returning without raising; the probe
    is retried every 100 ms.  On success the probe is NOT invoked again (the
    previous version redundantly re-ran it after the loop, doubling any side
    effects).  If the deadline expires, the last exception raised by the
    probe propagates to the caller.

    Parameters
    ----------
    callable : zero-argument assertion function (name kept for caller
        compatibility even though it shadows the builtin).
    timeout_ms : overall deadline in milliseconds.
    """
    import time

    deadline = time.monotonic_ns() + timeout_ms * 1000 * 1000
    while True:
        try:
            callable()
            return
        except Exception:
            # Past the deadline: surface the probe's failure to the caller.
            if time.monotonic_ns() >= deadline:
                raise
            time.sleep(0.1)

def test_basic_flow(self):
    """Drain a plain execute stream and verify the resulting RPC call counts."""
    mock_stub = self._stub_with([self.response, self.finished])
    iterator = ExecutePlanResponseReattachableIterator(
        self.request, mock_stub, self.policy, []
    )
    # Consume the iterator completely; only the side effects matter here.
    for _ in iterator:
        pass

    def verify_counts():
        self.assertEqual(0, mock_stub.attach_calls)
        self.assertEqual(1, mock_stub.release_until_calls)
        self.assertEqual(1, mock_stub.release_calls)
        self.assertEqual(1, mock_stub.execute_calls)

    # Release RPCs are issued asynchronously on the background pool,
    # so the counters may lag behind iterator completion.
    self.assertEventually(verify_counts, timeout_ms=1000)

def test_fail_during_execute(self):
def fatal():
Expand All @@ -167,9 +188,13 @@ def fatal():
for b in ite:
pass

self.assertEqual(0, stub.attach_calls)
self.assertEqual(0, stub.release_calls)
self.assertEqual(1, stub.execute_calls)
def check():
self.assertEqual(0, stub.attach_calls)
self.assertEqual(1, stub.release_calls)
self.assertEqual(1, stub.release_until_calls)
self.assertEqual(1, stub.execute_calls)

self.assertEventually(check, timeout_ms=1000)

def test_fail_and_retry_during_execute(self):
def non_fatal():
Expand All @@ -182,9 +207,13 @@ def non_fatal():
for b in ite:
pass

self.assertEqual(1, stub.attach_calls)
self.assertEqual(1, stub.release_calls)
self.assertEqual(1, stub.execute_calls)
def check():
self.assertEqual(1, stub.attach_calls)
self.assertEqual(1, stub.release_calls)
self.assertEqual(3, stub.release_until_calls)
self.assertEqual(1, stub.execute_calls)

self.assertEventually(check, timeout_ms=1000)

def test_fail_and_retry_during_reattach(self):
count = 0
Expand All @@ -204,9 +233,13 @@ def non_fatal():
for b in ite:
pass

self.assertEqual(2, stub.attach_calls)
self.assertEqual(2, stub.release_calls)
self.assertEqual(1, stub.execute_calls)
def check():
self.assertEqual(2, stub.attach_calls)
self.assertEqual(3, stub.release_until_calls)
self.assertEqual(1, stub.release_calls)
self.assertEqual(1, stub.execute_calls)

self.assertEventually(check, timeout_ms=1000)


class TestException(grpc.RpcError, grpc.Call):
Expand Down Expand Up @@ -257,6 +290,7 @@ def __init__(self, execute_ops=None, attach_ops=None):
# Call counters
self.execute_calls = 0
self.release_calls = 0
self.release_until_calls = 0
self.attach_calls = 0

def ExecutePlan(self, *args, **kwargs):
Expand All @@ -267,8 +301,12 @@ def ReattachExecute(self, *args, **kwargs):
self.attach_calls += 1
return self._attach_ops

def ReleaseExecute(self, *args, **kwargs):
self.release_calls += 1
def ReleaseExecute(self, req: "proto.ReleaseExecuteRequest", *args, **kwargs):
    """Count ReleaseExecute RPCs by request type.

    Increments ``release_calls`` for release_all requests and
    ``release_until_calls`` for release_until requests; other request
    shapes are ignored.  (Removed a leftover debug ``print`` that polluted
    test output; the annotation is a string so ``proto`` is not evaluated
    at definition time.)
    """
    if req.HasField("release_all"):
        self.release_calls += 1
    elif req.HasField("release_until"):
        self.release_until_calls += 1


class MockService:
Expand Down