From c88deea3c8d02a9a0290f0e80c369c6e59529677 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 2 Sep 2025 10:44:56 -0700 Subject: [PATCH] Treat additional error types as potential slice down issues. This change separates the set of `JaxRuntimeError` types that are considered indicative of a slice being down into `DATA_LOSS` and additional types. `DEADLINE_EXCEEDED`, `NOT_FOUND`, and `INTERNAL` are now treated as errors that may or may not be related to a slice being down; `is_error_due_to_slice_down` still returns true for them. PiperOrigin-RevId: 802205795 --- pathwaysutils/elastic/manager.py | 42 +++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index cf928bb..3986959 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -70,6 +70,9 @@ class Manager: _SIMPLE_EXECUTION_TEST_VALUE = 100 _ELASTIC_DOWN_ERROR_TYPES = [ "DATA_LOSS", + ] + _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ + "DEADLINE_EXCEEDED", "NOT_FOUND", "INTERNAL", ] @@ -168,24 +171,45 @@ def is_error_due_to_slice_down(self, error: Exception) -> bool: The error types that are considered due to slice down are jax.errors.JaxRuntimeError with the following error kind in the message: - DATA_LOSS + - DEADLINE_EXCEEDED - NOT_FOUND - INTERNAL Args: error: The error to check. 
""" - return_value = isinstance(error, jax.errors.JaxRuntimeError) and any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ERROR_TYPES - ) - if return_value: - _logger.info("Caught an error due to slice down") - else: + error_due_to_slice_down = False + traceback_logging_level = logging.DEBUG + + if isinstance(error, jax.errors.JaxRuntimeError): + if any( + error_type in str(error) + for error_type in self._ELASTIC_DOWN_ERROR_TYPES + ): + _logger.info("Caught an error due to slice down") + + error_due_to_slice_down = True + + elif any( + error_type in str(error) + for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES + ): + _logger.warning( + "Caught an error due that may or may not be due to slice down. This" + " error will be treated as due to slice down." + ) + traceback_logging_level = logging.WARNING + + error_due_to_slice_down = True + + if not error_due_to_slice_down: _logger.info("Caught an error not due to slice down") - _logger.debug("\n".join(traceback.format_exception(error))) + _logger.log( + traceback_logging_level, "\n".join(traceback.format_exception(error)) + ) - return return_value + return error_due_to_slice_down def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: """Simple execution to test if a slice is available.