Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-45733][CONNECT][PYTHON] Support multiple retry policies #43591

Closed
wants to merge 16 commits into from
195 changes: 34 additions & 161 deletions python/pyspark/sql/connect/client/core.py
Expand Up @@ -20,19 +20,17 @@
]

from pyspark.loose_version import LooseVersion
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import threading
import os
import platform
import random
import time
import urllib.parse
import uuid
import sys
from types import TracebackType
from typing import (
Iterable,
Iterator,
Expand All @@ -45,9 +43,6 @@
Set,
NoReturn,
cast,
Callable,
Generator,
Type,
TYPE_CHECKING,
Sequence,
)
Expand Down Expand Up @@ -550,13 +545,11 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
"""
class DefaultPolicy(RetryPolicy):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: DefaultRetryPolicy

def __init__(self, **kwargs):
super().__init__(**kwargs)

@classmethod
def retry_exception(cls, e: Exception) -> bool:
def can_retry(self, e: Exception) -> bool:
"""
Helper function that is used to identify if an exception thrown by the server
can be retried or not.
Expand All @@ -572,6 +565,9 @@ def retry_exception(cls, e: Exception) -> bool:
True if the exception can be retried, False otherwise.

"""
if isinstance(e, RetryException):
return True

if not isinstance(e, grpc.RpcError):
return False

Expand All @@ -587,6 +583,12 @@ def retry_exception(cls, e: Exception) -> bool:

return False


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
"""

def __init__(
self,
connection: Union[str, ChannelBuilder],
Expand Down Expand Up @@ -634,7 +636,10 @@ def __init__(
else ChannelBuilder(connection, channel_options)
)
self._user_id = None
self._retry_policy = {
self._known_retry_policies: Dict[str, RetryPolicy] = dict()
self._retry_policies: List[RetryPolicy] = []

default_policy_args = {
# Please synchronize changes here with Scala side
# GrpcRetryHandler.scala
#
Expand All @@ -648,7 +653,11 @@ def __init__(
"min_jitter_threshold": 2000,
}
if retry_policy:
self._retry_policy.update(retry_policy)
default_policy_args.update(retry_policy)

default_policy = DefaultPolicy(**default_policy_args)
self.register_retry_policy(default_policy)
self.set_retry_policies([default_policy.name])

if self._builder.session_id is None:
# Generate a unique session ID for this client. This UUID must be unique to allow
Expand Down Expand Up @@ -676,9 +685,7 @@ def __init__(
# Configure logging for the SparkConnect client.

def _retrying(self) -> "Retrying":
return Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore
)
return Retrying(self._retry_policies)

def disable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = False
Expand All @@ -688,6 +695,14 @@ def enable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = True
return self

def register_retry_policy(self, policy: RetryPolicy) -> None:
    """Make a retry policy available for activation by name.

    Registered policies are inert until selected via
    ``set_retry_policies``.

    Parameters
    ----------
    policy : RetryPolicy
        The policy instance to register, keyed by ``policy.name``.

    Raises
    ------
    ValueError
        If a policy with the same name is already registered.
    """
    if policy.name in self._known_retry_policies:
        # Include the duplicate name so the caller can tell which
        # registration collided.
        raise ValueError(f"Already known policy: {policy.name}")
    self._known_retry_policies[policy.name] = policy

def set_retry_policies(self, policies: List[str]) -> None:
    """Activate the given previously registered retry policies, in order.

    Replaces the current active policy list wholesale; policies not named
    here stay registered but inactive.

    Parameters
    ----------
    policies : List[str]
        Names of policies registered via ``register_retry_policy``.
        An unknown name raises ``KeyError``.
    """
    self._retry_policies = [self._known_retry_policies[name] for name in policies]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need the functionality of registering policies that we are not going to be using?
is there a use case for adding and removing policies?
unless there's a good reason to have a separate register and set, maybe simplify it and just set is enough?

Copy link
Contributor Author

@cdkrot cdkrot Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can consider more options, but my rationale here was that it makes it easier for users to configure retry policies without needing to toss objects around. Policy consists not just of the class but also parameters set on the instance.

For example, suppose you have a client which shipped with policies ["policyA", "policyB", "policyC"], and for some reason you aren't happy with policyB. Then it's easier to configure this in only one call without having to obtain objects for policyA and policyC. It's also convenient to add your own PolicyD into the mix.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just wondering if any user that is a "power" user enough to work with changing these policies would not be comfortable with tossing these objects around anyway. E.g. if they would want to tweak the parameters that you mention, they would anyway have to deal with the object, instantiate it with different parameters, register it and then set it.
For such a power user, it actually may be cleaner to just work with these objects directly instead of having a two step registration and setting.

But I don't have a strong opinion here, I am also fine with keeping these extra APIs.

Copy link
Contributor Author

@cdkrot cdkrot Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they would anyway have to deal with the object

probably it doesn't really require obtaining the actual object — just an import statement from somewhere and a new instance construction

Don't really object other ways to do it, just explaining why I did that :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a change in API which now also allows passing object directly


def register_udf(
self,
function: Any,
Expand Down Expand Up @@ -1153,7 +1168,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
handle_response(b)
Expand Down Expand Up @@ -1266,7 +1281,7 @@ def handle_response(
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
yield from handle_response(b)
Expand Down Expand Up @@ -1598,145 +1613,3 @@ def cache_artifact(self, blob: bytes) -> str:
with attempt:
return self._artifact_manager.cache_artifact(blob)
raise SparkConnectException("Invalid state during retry exception handling.")


class RetryState:
    """Mutable bookkeeping shared across the attempts of one retried task.

    Records the most recent exception raised by an attempt, counts how many
    attempts have failed so far, and remembers whether the task eventually
    completed. ``done()`` returns True once the task finished successfully.
    """

    def __init__(self) -> None:
        # Last exception observed; None until the first failure is recorded.
        self._exception: Optional[BaseException] = None
        self._done = False
        self._count = 0

    def set_exception(self, exc: BaseException) -> None:
        """Record one failed attempt and remember its exception."""
        self._count += 1
        self._exception = exc

    def set_done(self) -> None:
        """Mark the task as having completed successfully."""
        self._done = True

    def done(self) -> bool:
        """Return True if the task completed successfully."""
        return self._done

    def count(self) -> int:
        """Return the number of failed attempts recorded so far."""
        return self._count

    def exception(self) -> BaseException:
        """Return the last recorded exception; error if none was recorded."""
        if self._exception is None:
            raise RuntimeError("No exception is set")
        return self._exception

    def throw(self) -> None:
        """Re-raise the last recorded exception."""
        raise self.exception()


class AttemptManager:
    """Context manager wrapping a single retry attempt.

    An exception raised inside the ``with`` body is inspected on exit:
    retryable exceptions are swallowed and recorded on the shared
    ``RetryState`` so the surrounding loop can try again, non-retryable
    ones propagate to the caller, and a clean exit marks the task done.
    """

    def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
        # Predicate deciding whether a given exception may be retried.
        self._can_retry = check
        self._retry_state = retry_state

    def is_first_try(self) -> bool:
        """Return True while no attempt has failed yet."""
        return self._retry_state._count == 0

    def __enter__(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        if not isinstance(exc_val, BaseException):
            # Clean exit: the attempt succeeded.
            self._retry_state.set_done()
            return None
        # RetryException is always retryable, regardless of the predicate.
        if self._can_retry(exc_val) or isinstance(exc_val, RetryException):
            # Swallow the exception and remember it for the retry loop.
            self._retry_state.set_exception(exc_val)
            return True
        # Not retryable: bubble the exception up.
        return False


class Retrying:
    """Drive a retry loop with exponential backoff and jitter.

    Iterating over an instance yields one ``AttemptManager`` per attempt;
    the retried code goes inside ``with attempt:``. Exceptions accepted by
    ``can_retry`` are swallowed and the loop sleeps before the next
    attempt, with the delay growing by ``backoff_multiplier`` up to
    ``max_backoff`` and random jitter added once the delay reaches
    ``min_jitter_threshold``. When the retry budget is exhausted the last
    exception is re-raised.

    An example to use this class looks like this:

    .. code-block:: python

        for attempt in Retrying(can_retry=lambda x: isinstance(x, TransientError)):
            with attempt:
                # do the work.

    """

    def __init__(
        self,
        max_retries: int,
        initial_backoff: int,
        max_backoff: int,
        backoff_multiplier: float,
        jitter: int,
        min_jitter_threshold: int,
        can_retry: Callable[..., bool] = lambda x: True,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        # `sleep` is injectable so tests can avoid real waiting.
        self._can_retry = can_retry
        self._max_retries = max_retries
        self._initial_backoff = initial_backoff
        self._max_backoff = max_backoff
        self._backoff_multiplier = backoff_multiplier
        self._jitter = jitter
        self._min_jitter_threshold = min_jitter_threshold
        self._sleep = sleep

    def __iter__(self) -> Generator[AttemptManager, None, None]:
        """Yield one AttemptManager per attempt, sleeping between retries.

        Returns
        -------
        A generator that yields the current attempt.
        """
        if self._max_retries < 0:
            raise ValueError("Can't have negative number of retries")

        state = RetryState()
        # Delay (ms) to apply before the *next* retry; grows geometrically.
        upcoming_backoff: float = self._initial_backoff

        while not state.done() and state.count() <= self._max_retries:
            if state.count() > 0:
                # Not the first attempt: back off before retrying.
                wait_ms = upcoming_backoff
                upcoming_backoff = min(
                    self._max_backoff, upcoming_backoff * self._backoff_multiplier
                )

                # Add jitter only once delays are long enough to matter.
                if wait_ms >= self._min_jitter_threshold:
                    wait_ms += random.uniform(0, self._jitter)

                logger.debug(
                    f"Will retry call after {wait_ms} ms sleep (error: {state.exception()})"
                )
                self._sleep(wait_ms / 1000.0)
            yield AttemptManager(self._can_retry, state)

        if not state.done():
            # Exceeded the retry budget: surface the last failure.
            state.throw()
31 changes: 8 additions & 23 deletions python/pyspark/sql/connect/client/reattach.py
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.client.retries import Retrying
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
Expand All @@ -22,7 +23,7 @@
import warnings
import uuid
from collections.abc import Generator
from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
from multiprocessing.pool import ThreadPool
import os

Expand Down Expand Up @@ -83,12 +84,12 @@ def __init__(
self,
request: pb2.ExecutePlanRequest,
stub: grpc_lib.SparkConnectServiceStub,
retry_policy: Dict[str, Any],
retrying: Callable[[], Retrying],
metadata: Iterable[Tuple[str, str]],
):
ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
self._request = request
self._retry_policy = retry_policy
self._retrying = retrying
if request.operation_id:
self._operation_id = request.operation_id
else:
Expand Down Expand Up @@ -143,17 +144,12 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse:
return ret

def _has_next(self) -> bool:
from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

if self._result_complete:
# After response complete response
return False
else:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
if self._current is None:
try:
Expand Down Expand Up @@ -199,16 +195,11 @@ def _release_until(self, until_response_id: str) -> None:
if self._result_complete:
return

from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

request = self._create_release_execute_request(until_response_id)

def target() -> None:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request, metadata=self._metadata)
except Exception as e:
Expand All @@ -228,16 +219,11 @@ def _release_all(self) -> None:
if self._result_complete:
return

from pyspark.sql.connect.client.core import SparkConnectClient
from pyspark.sql.connect.client.core import Retrying

request = self._create_release_execute_request(None)

def target() -> None:
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request, metadata=self._metadata)
except Exception as e:
Expand Down Expand Up @@ -335,6 +321,5 @@ def __del__(self) -> None:

class RetryException(Exception):
"""
An exception that can be thrown upstream when inside retry and which will be retryable
regardless of policy.
An exception that can be thrown upstream when inside retry and which is always retryable
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe move to retries.py? I think it shouldn't have landed in this file in the first place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved