Skip to content

Commit b746f2e

Browse files
committed
fix(async): exclude all retryable errors from early-shutdown gate
The gate previously only excluded `ModelRateLimitError`, leaving `ModelTimeoutError`, `ModelInternalServerError`, and `ModelAPIConnectionError` to count toward the sliding-window error rate. Under provider degradation these errors cluster in time (concurrent in-flight requests time out together), so 5/10 in a row is easy and trips the gate even when salvage could recover the rows. Refs #575.
1 parent 93ae875 commit b746f2e

2 files changed

Lines changed: 104 additions & 3 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,9 +739,13 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
739739
trace.status = "ok"
740740

741741
except Exception as exc:
742-
if not isinstance(exc, ModelRateLimitError):
743-
self._check_error_rate(success=False)
744742
retryable = self._is_retryable(exc)
743+
# Only non-retryable errors (auth, schema, code bugs) count toward
744+
# the early-shutdown gate. Retryable errors (rate-limit, timeout,
745+
# transient 5xx, connection blips) cluster under provider degradation
746+
# and would otherwise trip the gate even when salvage could recover.
747+
if not retryable:
748+
self._check_error_rate(success=False)
745749
if not retryable and self._reporter:
746750
self._reporter.record_failure(task.column)
747751
if self._trace and trace:

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import asyncio
7+
from collections.abc import Callable
78
from typing import Any
89
from unittest.mock import MagicMock
910

@@ -31,7 +32,12 @@
3132
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
3233
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
3334
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager
34-
from data_designer.engine.models.errors import ModelInternalServerError, ModelRateLimitError
35+
from data_designer.engine.models.errors import (
36+
ModelAPIConnectionError,
37+
ModelInternalServerError,
38+
ModelRateLimitError,
39+
ModelTimeoutError,
40+
)
3541
from data_designer.engine.resources.resource_provider import ResourceProvider
3642

3743
MODEL_ALIAS = "stub"
@@ -167,6 +173,33 @@ def generate(self, data: dict) -> dict:
167173
return data
168174

169175

176+
class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]):
177+
"""Generator that raises a parametrizable retryable error then succeeds."""
178+
179+
def __init__(
180+
self,
181+
*args: Any,
182+
error_factory: Callable[[], Exception],
183+
retryable_failures: int = 0,
184+
**kwargs: Any,
185+
) -> None:
186+
super().__init__(*args, **kwargs)
187+
self._error_factory = error_factory
188+
self._retryable_failures = retryable_failures
189+
self._calls = 0
190+
191+
@staticmethod
192+
def get_generation_strategy() -> GenerationStrategy:
193+
return GenerationStrategy.CELL_BY_CELL
194+
195+
def generate(self, data: dict) -> dict:
196+
self._calls += 1
197+
if self._calls <= self._retryable_failures:
198+
raise self._error_factory()
199+
data[self.config.name] = f"ok_{data.get('seed', '?')}"
200+
return data
201+
202+
170203
# -- Helper to build graph + scheduler ----------------------------------------
171204

172205

@@ -822,6 +855,70 @@ async def test_rate_limit_errors_do_not_trigger_early_shutdown() -> None:
822855
assert tracker.is_row_group_complete(0, 10, ["seed", "col"])
823856

824857

858+
@pytest.mark.parametrize(
859+
"error_factory",
860+
[
861+
pytest.param(lambda: ModelRateLimitError("429 Too Many Requests"), id="rate_limit"),
862+
pytest.param(lambda: ModelTimeoutError("read timeout"), id="timeout"),
863+
pytest.param(lambda: ModelInternalServerError("503 Service Unavailable"), id="internal_server"),
864+
pytest.param(lambda: ModelAPIConnectionError("connection reset"), id="api_connection"),
865+
],
866+
)
867+
@pytest.mark.asyncio(loop_scope="session")
868+
async def test_retryable_errors_do_not_trigger_early_shutdown(
869+
error_factory: Callable[[], Exception],
870+
) -> None:
871+
"""All retryable errors (rate-limit, timeout, 5xx, connection) bypass the early-shutdown gate.
872+
873+
Regression test for #575: clustered ``ModelTimeoutError`` during provider degradation
874+
used to trip the gate even though salvage could recover the rows.
875+
"""
876+
provider = _mock_provider()
877+
configs = [
878+
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
879+
LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
880+
]
881+
strategies = {
882+
"seed": GenerationStrategy.FULL_COLUMN,
883+
"col": GenerationStrategy.CELL_BY_CELL,
884+
}
885+
generators = {
886+
"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
887+
"col": MockRetryableErrorGenerator(
888+
config=_expr_config("col"),
889+
resource_provider=provider,
890+
error_factory=error_factory,
891+
retryable_failures=8,
892+
),
893+
}
894+
895+
graph = ExecutionGraph.create(configs, strategies)
896+
row_groups = [(0, 10)]
897+
tracker = CompletionTracker.with_graph(graph, row_groups)
898+
899+
storage = MagicMock()
900+
storage.dataset_name = "test"
901+
storage.get_file_paths.return_value = {}
902+
storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
903+
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
904+
buffer_mgr = RowGroupBufferManager(storage)
905+
906+
scheduler = AsyncTaskScheduler(
907+
generators=generators,
908+
graph=graph,
909+
tracker=tracker,
910+
row_groups=row_groups,
911+
buffer_manager=buffer_mgr,
912+
shutdown_error_rate=0.5,
913+
shutdown_error_window=10,
914+
)
915+
await scheduler.run()
916+
917+
assert not scheduler._early_shutdown
918+
assert scheduler._recent_outcomes.count(False) == 0
919+
assert tracker.is_row_group_complete(0, 10, ["seed", "col"])
920+
921+
825922
@pytest.mark.asyncio(loop_scope="session")
826923
async def test_scheduler_on_before_checkpoint_callback() -> None:
827924
"""on_before_checkpoint is called before each row group is checkpointed."""

0 commit comments

Comments
 (0)