[SPARK-45733][CONNECT][PYTHON] Support multiple retry policies #43591
@@ -20,19 +20,17 @@
 ]

 from pyspark.loose_version import LooseVersion
+from pyspark.sql.connect.client.retries import RetryPolicy, Retrying
 from pyspark.sql.connect.utils import check_dependencies

 check_dependencies(__name__)

 import threading
 import os
 import platform
-import random
-import time
 import urllib.parse
 import uuid
 import sys
-from types import TracebackType
 from typing import (
     Iterable,
     Iterator,
@@ -45,9 +43,6 @@
     Set,
     NoReturn,
     cast,
-    Callable,
-    Generator,
-    Type,
     TYPE_CHECKING,
     Sequence,
 )
@@ -550,13 +545,11 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
         )


-class SparkConnectClient(object):
-    """
-    Conceptually the remote spark session that communicates with the server
-    """
+class DefaultPolicy(RetryPolicy):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)

-    @classmethod
-    def retry_exception(cls, e: Exception) -> bool:
+    def can_retry(self, e: Exception) -> bool:
         """
         Helper function that is used to identify if an exception thrown by the server
         can be retried or not.
@@ -572,6 +565,9 @@ def retry_exception(cls, e: Exception) -> bool:
         True if the exception can be retried, False otherwise.

         """
+        if isinstance(e, RetryException):
+            return True
+
         if not isinstance(e, grpc.RpcError):
             return False

@@ -587,6 +583,12 @@ def retry_exception(cls, e: Exception) -> bool:

         return False

+
+class SparkConnectClient(object):
+    """
+    Conceptually the remote spark session that communicates with the server
+    """
+
     def __init__(
         self,
         connection: Union[str, ChannelBuilder],
@@ -634,7 +636,10 @@ def __init__(
             else ChannelBuilder(connection, channel_options)
         )
         self._user_id = None
-        self._retry_policy = {
+        self._known_retry_policies: Dict[str, RetryPolicy] = dict()
+        self._retry_policies: List[RetryPolicy] = []
+
+        default_policy_args = {
             # Please synchronize changes here with Scala side
             # GrpcRetryHandler.scala
             #
@@ -648,7 +653,11 @@ def __init__(
             "min_jitter_threshold": 2000,
         }
         if retry_policy:
-            self._retry_policy.update(retry_policy)
+            default_policy_args.update(retry_policy)
+
+        default_policy = DefaultPolicy(**default_policy_args)
+        self.register_retry_policy(default_policy)
+        self.set_retry_policies([default_policy.name])

         if self._builder.session_id is None:
             # Generate a unique session ID for this client. This UUID must be unique to allow
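For illustration: the constructor's retry_policy dict is kept for compatibility, but it now only overrides the arguments of the default policy. A minimal sketch of such an override, assuming the parameter names from default_policy_args above; the connection string and values are placeholders:

from pyspark.sql.connect.client.core import SparkConnectClient

# Hypothetical override: fewer attempts and a shorter initial backoff.
# Values are in milliseconds, per the defaults shown in the hunk above.
client = SparkConnectClient(
    "sc://localhost",
    retry_policy={"max_retries": 5, "initial_backoff": 100},
)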
@@ -676,9 +685,7 @@ def __init__(
         # Configure logging for the SparkConnect client.

     def _retrying(self) -> "Retrying":
-        return Retrying(
-            can_retry=SparkConnectClient.retry_exception, **self._retry_policy  # type: ignore
-        )
+        return Retrying(self._retry_policies)

     def disable_reattachable_execute(self) -> "SparkConnectClient":
         self._use_reattachable_execute = False
@@ -688,6 +695,14 @@ def enable_reattachable_execute(self) -> "SparkConnectClient":
         self._use_reattachable_execute = True
         return self

+    def register_retry_policy(self, policy: RetryPolicy):
+        if policy.name in self._known_retry_policies:
+            raise ValueError("Already known policy")
+        self._known_retry_policies[policy.name] = policy
+
+    def set_retry_policies(self, policies: List[str]):
+        self._retry_policies = [self._known_retry_policies[name] for name in policies]
+
Review comment: do we need the functionality of registering policies that we are not going to be using?

Reply: We can consider more options, but my rationale here was that it makes it easier for users to configure retry policies without needing to toss objects around. A policy consists not just of the class but also of the parameters set on the instance. For example, suppose you have a client which shipped with policies ["policyA", "policyB", "policyC"], and for some reason you aren't happy with policyB. Then it's easier to configure this in a single call, without having to obtain objects for policyA and policyC. It's also convenient to add your own PolicyD into the mix.

Reply: I'm just wondering whether any user who is enough of a "power" user to work with changing these policies wouldn't be comfortable with tossing these objects around anyway. E.g. if they wanted to tweak the parameters you mention, they would have to deal with the object anyway: instantiate it with different parameters, register it, and then set it. But I don't have a strong opinion here; I am also fine with keeping these extra APIs.

Reply: Tweaking probably doesn't require obtaining the actual object, just an import statement from somewhere and constructing a new instance. I don't really object to other ways of doing it, just explaining why I did it this way :)

Reply: I made a change to the API which now also allows passing the policy object directly.
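To make the flow being discussed concrete, here is a hedged sketch of the names-based configuration; CustomPolicy is invented for illustration, and it assumes a policy's name defaults to its class name (consistent with the default_policy.name usage in the diff):

from pyspark.sql.connect.client.retries import RetryPolicy

class CustomPolicy(RetryPolicy):
    # Hypothetical policy: retry only connection-level errors.
    def can_retry(self, e: Exception) -> bool:
        return isinstance(e, ConnectionError)

# client: an existing SparkConnectClient (placeholder).
client.register_retry_policy(CustomPolicy(max_retries=3))
# One call reconfigures the active set by name; no need to hold
# references to the policy objects that shipped with the client.
client.set_retry_policies(["DefaultPolicy", "CustomPolicy"])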
     def register_udf(
         self,
         function: Any,
@@ -1153,7 +1168,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
         if self._use_reattachable_execute:
             # Don't use retryHandler - own retry handling is inside.
             generator = ExecutePlanResponseReattachableIterator(
-                req, self._stub, self._retry_policy, self._builder.metadata()
+                req, self._stub, self._retrying, self._builder.metadata()
             )
             for b in generator:
                 handle_response(b)
@@ -1266,7 +1281,7 @@ def handle_response(
         if self._use_reattachable_execute:
             # Don't use retryHandler - own retry handling is inside.
             generator = ExecutePlanResponseReattachableIterator(
-                req, self._stub, self._retry_policy, self._builder.metadata()
+                req, self._stub, self._retrying, self._builder.metadata()
             )
             for b in generator:
                 yield from handle_response(b)
@@ -1598,145 +1613,3 @@ def cache_artifact(self, blob: bytes) -> str:
             with attempt:
                 return self._artifact_manager.cache_artifact(blob)
         raise SparkConnectException("Invalid state during retry exception handling.")
-
-
-class RetryState:
-    """
-    Simple state helper that captures the state between retries of the exceptions. It
-    keeps track of the last exception thrown and how many in total. When the task
-    finishes successfully done() returns True.
-    """
-
-    def __init__(self) -> None:
-        self._exception: Optional[BaseException] = None
-        self._done = False
-        self._count = 0
-
-    def set_exception(self, exc: BaseException) -> None:
-        self._exception = exc
-        self._count += 1
-
-    def throw(self) -> None:
-        raise self.exception()
-
-    def exception(self) -> BaseException:
-        if self._exception is None:
-            raise RuntimeError("No exception is set")
-        return self._exception
-
-    def set_done(self) -> None:
-        self._done = True
-
-    def count(self) -> int:
-        return self._count
-
-    def done(self) -> bool:
-        return self._done
-
-
-class AttemptManager:
-    """
-    Simple ContextManager that is used to capture the exception thrown inside the context.
-    """
-
-    def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
-        self._retry_state = retry_state
-        self._can_retry = check
-
-    def __enter__(self) -> None:
-        pass
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> Optional[bool]:
-        if isinstance(exc_val, BaseException):
-            # Swallow the exception.
-            if self._can_retry(exc_val) or isinstance(exc_val, RetryException):
-                self._retry_state.set_exception(exc_val)
-                return True
-            # Bubble up the exception.
-            return False
-        else:
-            self._retry_state.set_done()
-            return None
-
-    def is_first_try(self) -> bool:
-        return self._retry_state._count == 0
-
-
-class Retrying:
-    """
-    This helper class is used as a generator together with a context manager to
-    allow retrying exceptions in particular code blocks. The Retrying can be configured
-    with a lambda function that is can be filtered what kind of exceptions should be
-    retried.
-
-    In addition, there are several parameters that are used to configure the exponential
-    backoff behavior.
-
-    An example to use this class looks like this:
-
-    .. code-block:: python
-
-        for attempt in Retrying(can_retry=lambda x: isinstance(x, TransientError)):
-            with attempt:
-                # do the work.
-
-    """
-
-    def __init__(
-        self,
-        max_retries: int,
-        initial_backoff: int,
-        max_backoff: int,
-        backoff_multiplier: float,
-        jitter: int,
-        min_jitter_threshold: int,
-        can_retry: Callable[..., bool] = lambda x: True,
-        sleep: Callable[[float], None] = time.sleep,
-    ) -> None:
-        self._can_retry = can_retry
-        self._max_retries = max_retries
-        self._initial_backoff = initial_backoff
-        self._max_backoff = max_backoff
-        self._backoff_multiplier = backoff_multiplier
-        self._jitter = jitter
-        self._min_jitter_threshold = min_jitter_threshold
-        self._sleep = sleep
-
-    def __iter__(self) -> Generator[AttemptManager, None, None]:
-        """
-        Generator function to wrap the exception producing code block.
-
-        Returns
-        -------
-        A generator that yields the current attempt.
-        """
-        retry_state = RetryState()
-        next_backoff: float = self._initial_backoff
-
-        if self._max_retries < 0:
-            raise ValueError("Can't have negative number of retries")
-
-        while not retry_state.done() and retry_state.count() <= self._max_retries:
-            # Do backoff
-            if retry_state.count() > 0:
-                # Randomize backoff for this iteration
-                backoff = next_backoff
-                next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier)
-
-                if backoff >= self._min_jitter_threshold:
-                    backoff += random.uniform(0, self._jitter)
-
-                logger.debug(
-                    f"Will retry call after {backoff} ms sleep (error: {retry_state.exception()})"
-                )
-                self._sleep(backoff / 1000.0)
-            yield AttemptManager(self._can_retry, retry_state)
-
-        if not retry_state.done():
-            # Exceeded number of retries, throw last exception we had
-            retry_state.throw()
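With this removal the retry machinery lives in retries.py, and Retrying is constructed from the list of active policies rather than from a can_retry callable plus backoff kwargs (see _retrying above). A minimal consumption sketch; do_rpc_call is a placeholder:

from pyspark.sql.connect.client.retries import Retrying

def do_rpc_call() -> None:
    ...  # placeholder for the gRPC call being retried

# client._retry_policies is the list assembled by set_retry_policies().
for attempt in Retrying(client._retry_policies):
    with attempt:
        do_rpc_call()

The attempt/with-block pattern is unchanged from the removed implementation; only the configuration moved into policy objects.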
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from pyspark.sql.connect.client.retries import Retrying
 from pyspark.sql.connect.utils import check_dependencies

 check_dependencies(__name__)
@@ -22,7 +23,7 @@
 import warnings
 import uuid
 from collections.abc import Generator
-from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
+from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
 from multiprocessing.pool import ThreadPool
 import os

@@ -83,12 +84,12 @@ def __init__(
         self,
         request: pb2.ExecutePlanRequest,
         stub: grpc_lib.SparkConnectServiceStub,
-        retry_policy: Dict[str, Any],
+        retrying: Callable[[], Retrying],
         metadata: Iterable[Tuple[str, str]],
     ):
         ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
         self._request = request
-        self._retry_policy = retry_policy
+        self._retrying = retrying
         if request.operation_id:
             self._operation_id = request.operation_id
         else:
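Note the signature change: instead of a policy dict that the iterator used to re-expand into a Retrying at each call site, it now receives a zero-argument factory, presumably so each retry loop builds a fresh Retrying with its own attempt state. The wiring at the core.py call sites is just the bound method:

# Sketch of the wiring; client is a SparkConnectClient and req an
# ExecutePlanRequest, both placeholders here.
generator = ExecutePlanResponseReattachableIterator(
    req, client._stub, client._retrying, client._builder.metadata()
)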
@@ -143,17 +144,12 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse:
         return ret

     def _has_next(self) -> bool:
-        from pyspark.sql.connect.client.core import SparkConnectClient
-        from pyspark.sql.connect.client.core import Retrying
-
         if self._result_complete:
             # After response complete response
             return False
         else:
             try:
-                for attempt in Retrying(
-                    can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-                ):
+                for attempt in self._retrying():
                     with attempt:
                         if self._current is None:
                             try:
@@ -199,16 +195,11 @@ def _release_until(self, until_response_id: str) -> None:
         if self._result_complete:
             return

-        from pyspark.sql.connect.client.core import SparkConnectClient
-        from pyspark.sql.connect.client.core import Retrying
-
         request = self._create_release_execute_request(until_response_id)

         def target() -> None:
             try:
-                for attempt in Retrying(
-                    can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-                ):
+                for attempt in self._retrying():
                     with attempt:
                         self._stub.ReleaseExecute(request, metadata=self._metadata)
             except Exception as e:
@@ -228,16 +219,11 @@ def _release_all(self) -> None:
         if self._result_complete:
             return

-        from pyspark.sql.connect.client.core import SparkConnectClient
-        from pyspark.sql.connect.client.core import Retrying
-
         request = self._create_release_execute_request(None)

         def target() -> None:
             try:
-                for attempt in Retrying(
-                    can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-                ):
+                for attempt in self._retrying():
                     with attempt:
                         self._stub.ReleaseExecute(request, metadata=self._metadata)
             except Exception as e:
@@ -335,6 +321,5 @@ def __del__(self) -> None:

 class RetryException(Exception):
     """
-    An exception that can be thrown upstream when inside retry and which will be retryable
-    regardless of policy.
+    An exception that can be thrown upstream when inside retry and which is always retryable
     """
Review comment: nit: maybe move to retries.py? I think it shouldn't have landed in this file in the first place.

Reply: Moved.
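For reference, a short sketch of what "always retryable" means in practice, assuming the post-move import path mentioned in this exchange; fetch_chunk and client are placeholders:

from pyspark.sql.connect.client.retries import RetryException, Retrying

def fetch_chunk() -> None:
    # Hypothetical body: a resumable interruption was detected mid-stream;
    # raising RetryException forces another attempt regardless of which
    # retry policies are currently active.
    raise RetryException()

for attempt in Retrying(client._retry_policies):
    with attempt:
        fetch_chunk()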
Review comment: nit: DefaultRetryPolicy