[SPARK-45922][CONNECT][CLIENT] Minor retries refactoring (follow-up to multiple policies)

### What changes were proposed in this pull request?

Follow-up to #43591.

Refactor the default policy arguments so they are declared as constructor arguments on the `DefaultPolicy` class rather than in a dict inside `core.py`.
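
A rough sketch of the shape of the change (simplified, hypothetical names; the real class lives in `python/pyspark/sql/connect/client/retries.py`):

```python
from typing import Optional


class DefaultPolicy:
    # After this PR: defaults are declared on the class constructor
    # instead of in a default_policy_args dict inside core.py.
    def __init__(self, max_retries: Optional[int] = 15, initial_backoff: int = 50):
        self.max_retries = max_retries
        self.initial_backoff = initial_backoff


# The call site in core.py now just forwards caller overrides:
retry_policy = {"max_retries": 3}  # hypothetical user-supplied override
policy = DefaultPolicy(**(retry_policy or {}))
assert policy.max_retries == 3 and policy.initial_backoff == 50
```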

### Why are the changes needed?
General refactoring; it also makes it easier to derive other policies from `DefaultPolicy`, as sketched below.
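
For example, a minimal sketch (hypothetical subclass over a simplified stand-in, not the real `RetryPolicy` hierarchy) of how a derived policy can override a single default instead of rebuilding a whole argument dict:

```python
class DefaultPolicy:
    def __init__(self, max_retries: int = 15, backoff_multiplier: float = 4.0):
        self.max_retries = max_retries
        self.backoff_multiplier = backoff_multiplier


class PatientPolicy(DefaultPolicy):
    # Hypothetical derived policy: inherits every default, overrides one.
    def __init__(self) -> None:
        super().__init__(max_retries=30)


assert PatientPolicy().backoff_multiplier == 4.0  # inherited default
```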

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing coverage

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43800 from cdkrot/SPARK-45922.

Authored-by: Alice Sayutina <alice.sayutina@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
cdkrot authored and HyukjinKwon committed Nov 25, 2023
1 parent 132bb63 commit 2f6a38c
Showing 4 changed files with 36 additions and 25 deletions.
```diff
@@ -55,7 +55,7 @@ object RetryPolicy {
   def defaultPolicy(): RetryPolicy = RetryPolicy(
     name = "DefaultPolicy",
     // Please synchronize changes here with Python side:
-    // pyspark/sql/connect/client/core.py
+    // pyspark/sql/connect/client/retries.py
     //
     // Note: these constants are selected so that the maximum tolerated wait is guaranteed
     // to be at least 10 minutes
```
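As a quick sanity check of the "at least 10 minutes" note: assuming the capped geometric backoff the field names suggest (wait_k = initial_backoff * backoff_multiplier**k, capped at max_backoff; jitter ignored since it only adds time), the defaults sum past ten minutes:

```python
# Defaults from the diff, all in milliseconds.
max_retries, multiplier, initial, cap = 15, 4.0, 50, 60_000

total_ms = sum(min(initial * multiplier**k, cap) for k in range(max_retries))
print(f"{total_ms / 60_000:.2f} minutes")  # ~10.14
assert total_ms >= 10 * 60 * 1000  # the documented guarantee holds
```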
python/pyspark/sql/connect/client/core.py (2 additions, 17 deletions)
```diff
@@ -595,23 +595,8 @@ def __init__(
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []

-        default_policy_args = {
-            # Please synchronize changes here with Scala side
-            # GrpcRetryHandler.scala
-            #
-            # Note: the number of retries is selected so that the maximum tolerated wait
-            # is guaranteed to be at least 10 minutes
-            "max_retries": 15,
-            "backoff_multiplier": 4.0,
-            "initial_backoff": 50,
-            "max_backoff": 60000,
-            "jitter": 500,
-            "min_jitter_threshold": 2000,
-        }
-        if retry_policy:
-            default_policy_args.update(retry_policy)
-
-        default_policy = DefaultPolicy(**default_policy_args)
+        retry_policy_args = retry_policy or dict()
+        default_policy = DefaultPolicy(**retry_policy_args)
         self.set_retry_policies([default_policy])

         if self._builder.session_id is None:
```
python/pyspark/sql/connect/client/retries.py (32 additions, 5 deletions)
```diff
@@ -185,6 +185,9 @@ def __init__(
         self._done = False

     def can_retry(self, exception: BaseException) -> bool:
+        if isinstance(exception, RetryException):
+            return True
+
         return any(policy.can_retry(exception) for policy in self._policies)

     def accept_exception(self, exception: BaseException) -> bool:
```
```diff
@@ -204,8 +207,12 @@ def _last_exception(self) -> BaseException:
     def _wait(self) -> None:
         exception = self._last_exception()

-        # Attempt to find a policy to wait with
+        if isinstance(exception, RetryException):
+            # Considered immediately retriable
+            logger.debug(f"Got error: {repr(exception)}. Retrying.")
+            return
+
+        # Attempt to find a policy to wait with
         for policy in self._policies:
             if not policy.can_retry(exception):
                 continue
```
```diff
@@ -244,12 +251,34 @@ def __iter__(self) -> Generator[AttemptManager, None, None]:
 class RetryException(Exception):
     """
     An exception that can be thrown upstream when inside retry and which is always retryable
+    even without policies
     """


 class DefaultPolicy(RetryPolicy):
-    def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
-        super().__init__(**kwargs)
+    # Please synchronize changes here with Scala side in
+    # org.apache.spark.sql.connect.client.RetryPolicy
+    #
+    # Note: the number of retries is selected so that the maximum tolerated wait
+    # is guaranteed to be at least 10 minutes
+
+    def __init__(
+        self,
+        max_retries: Optional[int] = 15,
+        backoff_multiplier: float = 4.0,
+        initial_backoff: int = 50,
+        max_backoff: Optional[int] = 60000,
+        jitter: int = 500,
+        min_jitter_threshold: int = 2000,
+    ):
+        super().__init__(
+            max_retries=max_retries,
+            backoff_multiplier=backoff_multiplier,
+            initial_backoff=initial_backoff,
+            max_backoff=max_backoff,
+            jitter=jitter,
+            min_jitter_threshold=min_jitter_threshold,
+        )

     def can_retry(self, e: BaseException) -> bool:
         """
```
```diff
@@ -267,8 +296,6 @@ def can_retry(self, e: BaseException) -> bool:
         True if the exception can be retried, False otherwise.
         """
-        if isinstance(e, RetryException):
-            return True

         if not isinstance(e, grpc.RpcError):
             return False
```
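The net effect of the retries.py change can be sketched standalone (simplified stand-ins, not the real classes): a `RetryException` is now deemed retryable by `Retrying` itself, even with an empty policy list, while every other error still needs a policy to vouch for it:

```python
class RetryException(Exception):
    """Always retryable, even without policies (mirrors the new docstring)."""


def can_retry(exception: BaseException, policies: list) -> bool:
    # Mirrors the new Retrying.can_retry: RetryException short-circuits
    # before any policy is consulted.
    if isinstance(exception, RetryException):
        return True
    return any(p.can_retry(exception) for p in policies)


assert can_retry(RetryException(), policies=[]) is True
assert can_retry(ValueError("boom"), policies=[]) is False
```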
python/pyspark/sql/tests/connect/client/test_client.py (1 addition, 2 deletions)
```diff
@@ -31,7 +31,6 @@
 from pyspark.sql.connect.client.retries import (
     Retrying,
     DefaultPolicy,
-    RetryException,
     RetriesExceeded,
 )
 from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
```
```diff
@@ -111,7 +110,7 @@ def sleep(t):
         try:
             for attempt in Retrying(client._retry_policies, sleep=sleep):
                 with attempt:
-                    raise RetryException()
+                    raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE)
         except RetriesExceeded:
             pass
```

