Skip to content

Commit

Permalink
[SPARK-45733][CONNECT][PYTHON] Support multiple retry policies
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Support multiple retry policies defined at the same time. Each policy determines which error types it can retry and how exactly those should be spread out.

### Why are the changes needed?

Different error types should be treated differently For instance, networking connectivity errors and remote resources being initialized should be treated separately.

### Does this PR introduce _any_ user-facing change?
No (as long as user doesn't poke within client internals).

### How was this patch tested?
Unit tests, some hand testing.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43591 from cdkrot/SPARK-45733.

Authored-by: Alice Sayutina <alice.sayutina@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
cdkrot authored and HyukjinKwon committed Nov 13, 2023
1 parent ba0e098 commit ef240cd
Show file tree
Hide file tree
Showing 5 changed files with 468 additions and 334 deletions.
218 changes: 26 additions & 192 deletions python/pyspark/sql/connect/client/core.py
Expand Up @@ -19,20 +19,16 @@
"SparkConnectClient",
]


from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import threading
import os
import platform
import random
import time
import urllib.parse
import uuid
import sys
from types import TracebackType
from typing import (
Iterable,
Iterator,
Expand All @@ -45,9 +41,6 @@
Set,
NoReturn,
cast,
Callable,
Generator,
Type,
TYPE_CHECKING,
Sequence,
)
Expand All @@ -66,10 +59,8 @@
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.client.reattach import (
ExecutePlanResponseReattachableIterator,
RetryException,
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
Expand Down Expand Up @@ -555,38 +546,6 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""

@classmethod
def retry_exception(cls, e: Exception) -> bool:
"""
Helper function that is used to identify if an exception thrown by the server
can be retried or not.
Parameters
----------
e : Exception
The GRPC error as received from the server. Typed as Exception, because other exception
thrown during client processing can be passed here as well.
Returns
-------
True if the exception can be retried, False otherwise.
"""
if not isinstance(e, grpc.RpcError):
return False

if e.code() in [grpc.StatusCode.INTERNAL]:
msg = str(e)

# This error happens if another RPC preempts this RPC.
if "INVALID_CURSOR.DISCONNECTED" in msg:
return True

if e.code() == grpc.StatusCode.UNAVAILABLE:
return True

return False

def __init__(
self,
connection: Union[str, ChannelBuilder],
Expand Down Expand Up @@ -634,7 +593,9 @@ def __init__(
else ChannelBuilder(connection, channel_options)
)
self._user_id = None
self._retry_policy = {
self._retry_policies: List[RetryPolicy] = []

default_policy_args = {
# Please synchronize changes here with Scala side
# GrpcRetryHandler.scala
#
Expand All @@ -648,7 +609,10 @@ def __init__(
"min_jitter_threshold": 2000,
}
if retry_policy:
self._retry_policy.update(retry_policy)
default_policy_args.update(retry_policy)

default_policy = DefaultPolicy(**default_policy_args)
self.set_retry_policies([default_policy])

if self._builder.session_id is None:
# Generate a unique session ID for this client. This UUID must be unique to allow
Expand Down Expand Up @@ -680,9 +644,7 @@ def __init__(
self._server_session_id: Optional[str] = None

def _retrying(self) -> "Retrying":
return Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore
)
return Retrying(self._retry_policies)

def disable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = False
Expand All @@ -692,6 +654,20 @@ def enable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = True
return self

def set_retry_policies(self, policies: Iterable[RetryPolicy]) -> None:
"""
Sets list of policies to be used for retries.
I.e. set_retry_policies([DefaultPolicy(), CustomPolicy()]).
"""
self._retry_policies = list(policies)

def get_retry_policies(self) -> List[RetryPolicy]:
"""
Return list of currently used policies
"""
return list(self._retry_policies)

def register_udf(
self,
function: Any,
Expand Down Expand Up @@ -1152,7 +1128,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
handle_response(b)
Expand Down Expand Up @@ -1262,7 +1238,7 @@ def handle_response(
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
yield from handle_response(b)
Expand Down Expand Up @@ -1641,145 +1617,3 @@ def _verify_response_integrity(
else:
# Update the server side session ID.
self._server_session_id = response.server_side_session_id


class RetryState:
"""
Simple state helper that captures the state between retries of the exceptions. It
keeps track of the last exception thrown and how many in total. When the task
finishes successfully done() returns True.
"""

def __init__(self) -> None:
self._exception: Optional[BaseException] = None
self._done = False
self._count = 0

def set_exception(self, exc: BaseException) -> None:
self._exception = exc
self._count += 1

def throw(self) -> None:
raise self.exception()

def exception(self) -> BaseException:
if self._exception is None:
raise RuntimeError("No exception is set")
return self._exception

def set_done(self) -> None:
self._done = True

def count(self) -> int:
return self._count

def done(self) -> bool:
return self._done


class AttemptManager:
"""
Simple ContextManager that is used to capture the exception thrown inside the context.
"""

def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
self._retry_state = retry_state
self._can_retry = check

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
if isinstance(exc_val, BaseException):
# Swallow the exception.
if self._can_retry(exc_val) or isinstance(exc_val, RetryException):
self._retry_state.set_exception(exc_val)
return True
# Bubble up the exception.
return False
else:
self._retry_state.set_done()
return None

def is_first_try(self) -> bool:
return self._retry_state._count == 0


class Retrying:
"""
This helper class is used as a generator together with a context manager to
allow retrying exceptions in particular code blocks. The Retrying can be configured
with a lambda function that is can be filtered what kind of exceptions should be
retried.
In addition, there are several parameters that are used to configure the exponential
backoff behavior.
An example to use this class looks like this:
.. code-block:: python
for attempt in Retrying(can_retry=lambda x: isinstance(x, TransientError)):
with attempt:
# do the work.
"""

def __init__(
self,
max_retries: int,
initial_backoff: int,
max_backoff: int,
backoff_multiplier: float,
jitter: int,
min_jitter_threshold: int,
can_retry: Callable[..., bool] = lambda x: True,
sleep: Callable[[float], None] = time.sleep,
) -> None:
self._can_retry = can_retry
self._max_retries = max_retries
self._initial_backoff = initial_backoff
self._max_backoff = max_backoff
self._backoff_multiplier = backoff_multiplier
self._jitter = jitter
self._min_jitter_threshold = min_jitter_threshold
self._sleep = sleep

def __iter__(self) -> Generator[AttemptManager, None, None]:
"""
Generator function to wrap the exception producing code block.
Returns
-------
A generator that yields the current attempt.
"""
retry_state = RetryState()
next_backoff: float = self._initial_backoff

if self._max_retries < 0:
raise ValueError("Can't have negative number of retries")

while not retry_state.done() and retry_state.count() <= self._max_retries:
# Do backoff
if retry_state.count() > 0:
# Randomize backoff for this iteration
backoff = next_backoff
next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier)

if backoff >= self._min_jitter_threshold:
backoff += random.uniform(0, self._jitter)

logger.debug(
f"Will retry call after {backoff} ms sleep (error: {retry_state.exception()})"
)
self._sleep(backoff / 1000.0)
yield AttemptManager(self._can_retry, retry_state)

if not retry_state.done():
# Exceeded number of retries, throw last exception we had
retry_state.throw()
35 changes: 7 additions & 28 deletions python/pyspark/sql/connect/client/reattach.py
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.client.retries import Retrying, RetryException
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
Expand All @@ -22,7 +23,7 @@
import warnings
import uuid
from collections.abc import Generator
from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from multiprocessing.pool import ThreadPool
import os

Expand Down Expand Up @@ -83,12 +84,12 @@ def __init__(
self,
request: pb2.ExecutePlanRequest,
stub: grpc_lib.SparkConnectServiceStub,
retry_policy: Dict[str, Any],
retrying: Callable[[], Retrying],
metadata: Iterable[Tuple[str, str]],
):
ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
self._request = request
self._retry_policy = retry_policy
self._retrying = retrying
if request.operation_id:
self._operation_id = request.operation_id
else:
Expand Down Expand Up @@ -143,17 +144,12 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse:
return ret

def _has_next(self) -> bool:
from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

if self._result_complete:
# After response complete response
return False
else:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
if self._current is None:
try:
Expand Down Expand Up @@ -199,16 +195,11 @@ def _release_until(self, until_response_id: str) -> None:
if self._result_complete:
return

from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

request = self._create_release_execute_request(until_response_id)

def target() -> None:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request, metadata=self._metadata)
except Exception as e:
Expand All @@ -228,16 +219,11 @@ def _release_all(self) -> None:
if self._result_complete:
return

from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

request = self._create_release_execute_request(None)

def target() -> None:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request, metadata=self._metadata)
except Exception as e:
Expand Down Expand Up @@ -331,10 +317,3 @@ def close(self) -> None:

def __del__(self) -> None:
return self.close()


class RetryException(Exception):
"""
An exception that can be thrown upstream when inside retry and which will be retryable
regardless of policy.
"""

0 comments on commit ef240cd

Please sign in to comment.