|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | 6 | import asyncio |
| 7 | +from collections.abc import Callable |
7 | 8 | from typing import Any |
8 | 9 | from unittest.mock import MagicMock |
9 | 10 |
|
|
31 | 32 | from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker |
32 | 33 | from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph |
33 | 34 | from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager |
34 | | -from data_designer.engine.models.errors import ModelInternalServerError, ModelRateLimitError |
| 35 | +from data_designer.engine.models.errors import ( |
| 36 | + ModelAPIConnectionError, |
| 37 | + ModelInternalServerError, |
| 38 | + ModelRateLimitError, |
| 39 | + ModelTimeoutError, |
| 40 | +) |
35 | 41 | from data_designer.engine.resources.resource_provider import ResourceProvider |
36 | 42 |
|
37 | 43 | MODEL_ALIAS = "stub" |
@@ -167,6 +173,33 @@ def generate(self, data: dict) -> dict: |
167 | 173 | return data |
168 | 174 |
|
169 | 175 |
|
| 176 | +class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]): |
| 177 | + """Generator that raises a parametrizable retryable error then succeeds.""" |
| 178 | + |
| 179 | + def __init__( |
| 180 | + self, |
| 181 | + *args: Any, |
| 182 | + error_factory: Callable[[], Exception], |
| 183 | + retryable_failures: int = 0, |
| 184 | + **kwargs: Any, |
| 185 | + ) -> None: |
| 186 | + super().__init__(*args, **kwargs) |
| 187 | + self._error_factory = error_factory |
| 188 | + self._retryable_failures = retryable_failures |
| 189 | + self._calls = 0 |
| 190 | + |
| 191 | + @staticmethod |
| 192 | + def get_generation_strategy() -> GenerationStrategy: |
| 193 | + return GenerationStrategy.CELL_BY_CELL |
| 194 | + |
| 195 | + def generate(self, data: dict) -> dict: |
| 196 | + self._calls += 1 |
| 197 | + if self._calls <= self._retryable_failures: |
| 198 | + raise self._error_factory() |
| 199 | + data[self.config.name] = f"ok_{data.get('seed', '?')}" |
| 200 | + return data |
| 201 | + |
| 202 | + |
170 | 203 | # -- Helper to build graph + scheduler ---------------------------------------- |
171 | 204 |
|
172 | 205 |
|
@@ -822,6 +855,70 @@ async def test_rate_limit_errors_do_not_trigger_early_shutdown() -> None: |
822 | 855 | assert tracker.is_row_group_complete(0, 10, ["seed", "col"]) |
823 | 856 |
|
824 | 857 |
|
| 858 | +@pytest.mark.parametrize( |
| 859 | + "error_factory", |
| 860 | + [ |
| 861 | + pytest.param(lambda: ModelRateLimitError("429 Too Many Requests"), id="rate_limit"), |
| 862 | + pytest.param(lambda: ModelTimeoutError("read timeout"), id="timeout"), |
| 863 | + pytest.param(lambda: ModelInternalServerError("503 Service Unavailable"), id="internal_server"), |
| 864 | + pytest.param(lambda: ModelAPIConnectionError("connection reset"), id="api_connection"), |
| 865 | + ], |
| 866 | +) |
| 867 | +@pytest.mark.asyncio(loop_scope="session") |
| 868 | +async def test_retryable_errors_do_not_trigger_early_shutdown( |
| 869 | + error_factory: Callable[[], Exception], |
| 870 | +) -> None: |
| 871 | + """All retryable errors (rate-limit, timeout, 5xx, connection) bypass the early-shutdown gate. |
| 872 | +
|
| 873 | + Regression test for #575: clustered ``ModelTimeoutError`` during provider degradation |
| 874 | + used to trip the gate even though salvage could recover the rows. |
| 875 | + """ |
| 876 | + provider = _mock_provider() |
| 877 | + configs = [ |
| 878 | + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), |
| 879 | + LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), |
| 880 | + ] |
| 881 | + strategies = { |
| 882 | + "seed": GenerationStrategy.FULL_COLUMN, |
| 883 | + "col": GenerationStrategy.CELL_BY_CELL, |
| 884 | + } |
| 885 | + generators = { |
| 886 | + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), |
| 887 | + "col": MockRetryableErrorGenerator( |
| 888 | + config=_expr_config("col"), |
| 889 | + resource_provider=provider, |
| 890 | + error_factory=error_factory, |
| 891 | + retryable_failures=8, |
| 892 | + ), |
| 893 | + } |
| 894 | + |
| 895 | + graph = ExecutionGraph.create(configs, strategies) |
| 896 | + row_groups = [(0, 10)] |
| 897 | + tracker = CompletionTracker.with_graph(graph, row_groups) |
| 898 | + |
| 899 | + storage = MagicMock() |
| 900 | + storage.dataset_name = "test" |
| 901 | + storage.get_file_paths.return_value = {} |
| 902 | + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" |
| 903 | + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" |
| 904 | + buffer_mgr = RowGroupBufferManager(storage) |
| 905 | + |
| 906 | + scheduler = AsyncTaskScheduler( |
| 907 | + generators=generators, |
| 908 | + graph=graph, |
| 909 | + tracker=tracker, |
| 910 | + row_groups=row_groups, |
| 911 | + buffer_manager=buffer_mgr, |
| 912 | + shutdown_error_rate=0.5, |
| 913 | + shutdown_error_window=10, |
| 914 | + ) |
| 915 | + await scheduler.run() |
| 916 | + |
| 917 | + assert not scheduler._early_shutdown |
| 918 | + assert scheduler._recent_outcomes.count(False) == 0 |
| 919 | + assert tracker.is_row_group_complete(0, 10, ["seed", "col"]) |
| 920 | + |
| 921 | + |
825 | 922 | @pytest.mark.asyncio(loop_scope="session") |
826 | 923 | async def test_scheduler_on_before_checkpoint_callback() -> None: |
827 | 924 | """on_before_checkpoint is called before each row group is checkpointed.""" |
|
0 commit comments