Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,19 @@ def slice_device_count(self, slice_index: int) -> int:
) from error

@classmethod
def _is_error_due_to_slice_down(cls, error: Exception) -> bool:
"""Check if the error is due to slice down."""
return_value = any(
def is_error_due_to_slice_down(cls, error: Exception) -> bool:
"""Returns True if the error is due to slice down.

The error types that are considered due to slice down are
jax.errors.JaxRuntimeError with the following error kind in the message:
- DATA_LOSS
- NOT_FOUND
- INTERNAL

Args:
error: The error to check.
"""
return_value = isinstance(error, jax.errors.JaxRuntimeError) and any(
error_type in str(error) for error_type in cls._ELASTIC_DOWN_ERROR_TYPES
)
if return_value:
Expand Down Expand Up @@ -221,7 +231,7 @@ def get_slice_availability(self) -> set[int]:
f"Error with _simple_execution for slice_index={slice_index}."
)
except jax.errors.JaxRuntimeError as error:
if not self._is_error_due_to_slice_down(error):
if not self.is_error_due_to_slice_down(error):
raise
_logger.info("slice_index=%s bad", slice_index)

Expand Down Expand Up @@ -565,7 +575,7 @@ def maybe_reshard_down(
handler_kwargs = {}

while True:
if not self._is_error_due_to_slice_down(error):
if not self.is_error_due_to_slice_down(error):
_logger.info(
"Not resharding down because the error is not due to a slice down."
)
Expand Down