diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 65a4b62..8b91e4c 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -147,9 +147,19 @@ def slice_device_count(self, slice_index: int) -> int: ) from error @classmethod - def _is_error_due_to_slice_down(cls, error: Exception) -> bool: - """Check if the error is due to slice down.""" - return_value = any( + def is_error_due_to_slice_down(cls, error: Exception) -> bool: + """Returns True if the error is due to slice down. + + The error types that are considered due to slice down are + jax.errors.JaxRuntimeError with the following error kind in the message: + - DATA_LOSS + - NOT_FOUND + - INTERNAL + + Args: + error: The error to check. + """ + return_value = isinstance(error, jax.errors.JaxRuntimeError) and any( error_type in str(error) for error_type in cls._ELASTIC_DOWN_ERROR_TYPES ) if return_value: @@ -221,7 +231,7 @@ def get_slice_availability(self) -> set[int]: f"Error with _simple_execution for slice_index={slice_index}." ) except jax.errors.JaxRuntimeError as error: - if not self._is_error_due_to_slice_down(error): + if not self.is_error_due_to_slice_down(error): raise _logger.info("slice_index=%s bad", slice_index) @@ -565,7 +575,7 @@ def maybe_reshard_down( handler_kwargs = {} while True: - if not self._is_error_due_to_slice_down(error): + if not self.is_error_due_to_slice_down(error): _logger.info( "Not resharding down because the error is not due to a slice down." )