[SPARK-45733][CONNECT][PYTHON] Support multiple retry policies #43591

Closed
wants to merge 16 commits
219 changes: 27 additions & 192 deletions python/pyspark/sql/connect/client/core.py
@@ -19,20 +19,16 @@
"SparkConnectClient",
]


from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import threading
import os
import platform
import random
import time
import urllib.parse
import uuid
import sys
from types import TracebackType
from typing import (
Iterable,
Iterator,
@@ -45,9 +45,6 @@
Set,
NoReturn,
cast,
Callable,
Generator,
Type,
TYPE_CHECKING,
Sequence,
)
@@ -66,10 +66,8 @@
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.client.reattach import (
ExecutePlanResponseReattachableIterator,
RetryException,
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
@@ -556,38 +547,6 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

@classmethod
def retry_exception(cls, e: Exception) -> bool:
"""
Helper function used to identify whether an exception thrown by the server
can be retried.

Parameters
----------
e : Exception
The gRPC error as received from the server. Typed as Exception because other exceptions
thrown during client processing can be passed here as well.

Returns
-------
True if the exception can be retried, False otherwise.

"""
if not isinstance(e, grpc.RpcError):
return False

if e.code() in [grpc.StatusCode.INTERNAL]:
msg = str(e)

# This error happens if another RPC preempts this RPC.
if "INVALID_CURSOR.DISCONNECTED" in msg:
return True

if e.code() == grpc.StatusCode.UNAVAILABLE:
return True

return False

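Editor's note: the removed classmethod above decided retryability inline (UNAVAILABLE always, INTERNAL only for INVALID_CURSOR.DISCONNECTED). With the retries module this diff imports, the same check would live inside a policy. The sketch below is illustrative only: it assumes RetryPolicy (pyspark.sql.connect.client.retries) can be subclassed with default constructor arguments and exposes a can_retry(exception) hook; the class name is invented here.

import grpc

from pyspark.sql.connect.client.retries import RetryPolicy


class GrpcStatusPolicy(RetryPolicy):
    """Hypothetical policy mirroring the removed retry_exception() logic."""

    def can_retry(self, e: BaseException) -> bool:
        if not isinstance(e, grpc.RpcError):
            return False
        # INTERNAL is retried only when another RPC preempted this one.
        if e.code() == grpc.StatusCode.INTERNAL:
            return "INVALID_CURSOR.DISCONNECTED" in str(e)
        # Transient server unavailability is always retryable.
        return e.code() == grpc.StatusCode.UNAVAILABLE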
def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -635,7 +594,10 @@ def __init__(
else ChannelBuilder(connection, channel_options)
)
self._user_id = None
self._retry_policy = {
self._known_retry_policies: Dict[str, RetryPolicy] = dict()
self._retry_policies: List[RetryPolicy] = []

default_policy_args = {
# Please synchronize changes here with Scala side
# GrpcRetryHandler.scala
#
@@ -649,7 +611,10 @@ def __init__(
"min_jitter_threshold": 2000,
}
if retry_policy:
self._retry_policy.update(retry_policy)
default_policy_args.update(retry_policy)

default_policy = DefaultPolicy(**default_policy_args)
self.set_retry_policies([default_policy])

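Editor's note on the constructor change above: the legacy retry_policy dict no longer configures the client directly; its entries override default_policy_args before the single DefaultPolicy is built and installed via set_retry_policies. A minimal sketch follows, assuming retry_policy remains a keyword argument of SparkConnectClient and that these key names (taken from the former Retrying.__init__) are still accepted by DefaultPolicy.

from pyspark.sql.connect.client.core import SparkConnectClient

# Overrides are folded into the single DefaultPolicy instance (sketch only;
# key names and the keyword argument are assumptions based on this diff).
client = SparkConnectClient(
    "sc://localhost",
    retry_policy={"max_retries": 5, "initial_backoff": 100},
)
print(client.get_retry_policies())  # expected: a one-element list with the DefaultPolicy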
if self._builder.session_id is None:
# Generate a unique session ID for this client. This UUID must be unique to allow
@@ -677,9 +642,7 @@ def __init__(
# Configure logging for the SparkConnect client.

def _retrying(self) -> "Retrying":
return Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore
)
return Retrying(self._retry_policies)

def disable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = False
@@ -689,6 +652,20 @@ def enable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = True
return self

def set_retry_policies(self, policies: Iterable[RetryPolicy]):
"""
Sets the list of policies to be used for retries, e.g.
set_retry_policies([DefaultPolicy(), CustomPolicy()]).

"""
self._retry_policies = list(policies)

def get_retry_policies(self) -> List[RetryPolicy]:
"""
Returns the list of currently used policies.
"""
return list(self._retry_policies)

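Editor's note: the pair of methods above is the public surface of this PR. A usage sketch follows; CustomPolicy is hypothetical and assumes, as above, that RetryPolicy subclasses override a can_retry hook, while DefaultPolicy comes from the retries module imported in this diff.

from pyspark.sql.connect.client.retries import DefaultPolicy, RetryPolicy


class CustomPolicy(RetryPolicy):
    """Hypothetical policy retrying a project-specific transient error."""

    def can_retry(self, e: BaseException) -> bool:
        return "MY_TRANSIENT_ERROR" in str(e)


# `client` is an already-constructed SparkConnectClient instance.
client.set_retry_policies([DefaultPolicy(), CustomPolicy()])
assert len(client.get_retry_policies()) == 2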
def register_udf(
self,
function: Any,
@@ -1154,7 +1131,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
handle_response(b)
@@ -1267,7 +1244,7 @@ def handle_response(
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
yield from handle_response(b)
@@ -1619,145 +1596,3 @@ def cache_artifact(self, blob: bytes) -> str:
with attempt:
return self._artifact_manager.cache_artifact(blob)
raise SparkConnectException("Invalid state during retry exception handling.")


class RetryState:
"""
Simple state helper that captures the state between retries. It keeps track of the
last exception thrown and how many retries have occurred in total. When the task
finishes successfully, done() returns True.
"""

def __init__(self) -> None:
self._exception: Optional[BaseException] = None
self._done = False
self._count = 0

def set_exception(self, exc: BaseException) -> None:
self._exception = exc
self._count += 1

def throw(self) -> None:
raise self.exception()

def exception(self) -> BaseException:
if self._exception is None:
raise RuntimeError("No exception is set")
return self._exception

def set_done(self) -> None:
self._done = True

def count(self) -> int:
return self._count

def done(self) -> bool:
return self._done


class AttemptManager:
"""
Simple ContextManager that is used to capture the exception thrown inside the context.
"""

def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
self._retry_state = retry_state
self._can_retry = check

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
if isinstance(exc_val, BaseException):
# Swallow the exception.
if self._can_retry(exc_val) or isinstance(exc_val, RetryException):
self._retry_state.set_exception(exc_val)
return True
# Bubble up the exception.
return False
else:
self._retry_state.set_done()
return None

def is_first_try(self) -> bool:
return self._retry_state._count == 0


class Retrying:
"""
This helper class is used as a generator together with a context manager to
allow retrying exceptions in particular code blocks. Retrying can be configured
with a lambda function that filters which kinds of exceptions should be
retried.

In addition, there are several parameters that are used to configure the exponential
backoff behavior.

An example to use this class looks like this:

.. code-block:: python

for attempt in Retrying(can_retry=lambda x: isinstance(x, TransientError)):
with attempt:
# do the work.

"""

def __init__(
self,
max_retries: int,
initial_backoff: int,
max_backoff: int,
backoff_multiplier: float,
jitter: int,
min_jitter_threshold: int,
can_retry: Callable[..., bool] = lambda x: True,
sleep: Callable[[float], None] = time.sleep,
) -> None:
self._can_retry = can_retry
self._max_retries = max_retries
self._initial_backoff = initial_backoff
self._max_backoff = max_backoff
self._backoff_multiplier = backoff_multiplier
self._jitter = jitter
self._min_jitter_threshold = min_jitter_threshold
self._sleep = sleep

def __iter__(self) -> Generator[AttemptManager, None, None]:
"""
Generator function to wrap the exception producing code block.

Returns
-------
A generator that yields the current attempt.
"""
retry_state = RetryState()
next_backoff: float = self._initial_backoff

if self._max_retries < 0:
raise ValueError("Can't have negative number of retries")

while not retry_state.done() and retry_state.count() <= self._max_retries:
# Do backoff
if retry_state.count() > 0:
# Randomize backoff for this iteration
backoff = next_backoff
next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier)

if backoff >= self._min_jitter_threshold:
backoff += random.uniform(0, self._jitter)

logger.debug(
f"Will retry call after {backoff} ms sleep (error: {retry_state.exception()})"
)
self._sleep(backoff / 1000.0)
yield AttemptManager(self._can_retry, retry_state)

if not retry_state.done():
# Exceeded number of retries, throw last exception we had
retry_state.throw()
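Editor's note: for comparison with the class removed above, the retry loop after this PR is driven by the Retrying imported from pyspark.sql.connect.client.retries, which is constructed from a list of policies instead of backoff kwargs plus a can_retry callable (see _retrying() and cache_artifact() earlier in this diff). A minimal sketch, assuming the new Retrying keeps the attempt context-manager semantics of the removed class; do_work is a placeholder.

from pyspark.sql.connect.client.retries import DefaultPolicy, Retrying


def call_with_retries(do_work):
    # Each iteration yields an attempt context manager; a retryable failure raised
    # inside the `with` block triggers backoff and another attempt, mirroring the
    # `for attempt in self._retrying()` pattern used by cache_artifact() above.
    for attempt in Retrying([DefaultPolicy()]):
        with attempt:
            return do_work()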