Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class Manager:
_SIMPLE_EXECUTION_TEST_VALUE = 100
_ELASTIC_DOWN_ERROR_TYPES = [
"DATA_LOSS",
]
_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [
"DEADLINE_EXCEEDED",
"NOT_FOUND",
"INTERNAL",
]
Expand Down Expand Up @@ -168,24 +171,45 @@ def is_error_due_to_slice_down(self, error: Exception) -> bool:
An error is considered due to slice down when it is a
jax.errors.JaxRuntimeError whose message contains one of these error kinds:
- DATA_LOSS
- DEADLINE_EXCEEDED
- NOT_FOUND
- INTERNAL

Args:
error: The error to check.
"""
return_value = isinstance(error, jax.errors.JaxRuntimeError) and any(
error_type in str(error)
for error_type in self._ELASTIC_DOWN_ERROR_TYPES
)
if return_value:
_logger.info("Caught an error due to slice down")
else:
error_due_to_slice_down = False
traceback_logging_level = logging.DEBUG

if isinstance(error, jax.errors.JaxRuntimeError):
if any(
error_type in str(error)
for error_type in self._ELASTIC_DOWN_ERROR_TYPES
):
_logger.info("Caught an error due to slice down")

error_due_to_slice_down = True

elif any(
error_type in str(error)
for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
):
_logger.warning(
"Caught an error due that may or may not be due to slice down. This"
" error will be treated as due to slice down."
)
traceback_logging_level = logging.WARNING

error_due_to_slice_down = True

if not error_due_to_slice_down:
_logger.info("Caught an error not due to slice down")

_logger.debug("\n".join(traceback.format_exception(error)))
_logger.log(
traceback_logging_level, "\n".join(traceback.format_exception(error))
)

return return_value
return error_due_to_slice_down

def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
"""Simple execution to test if a slice is available.
Expand Down