
Commit 6a85d1e
feat(async): WARN log when provider showing degraded performance
Diagnostic A/Bs against build.nvidia.com showed runs failing silently under provider degradation: no log indication that retryable errors were piling up until the early-shutdown gate fired (or, post-fix, until salvage exhaustion). Surfacing this earlier helps users distinguish "DataDesigner is broken" from "the upstream provider is slow today."

Tracks a separate sliding window over retryable-vs-not for every task outcome (independent of the early-shutdown gate's window) and emits a throttled WARN when the rolling fraction crosses the threshold.

Refs #575.
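For orientation, the two sliding windows consume different signals per task outcome (a summary read off the diff below; the gate deliberately skips retryable failures so salvage can still recover them):

  task outcome             early-shutdown gate window    degraded-WARN window
  success                  counted as success            counted, retryable=False
  retryable failure        not counted                   counted, retryable=True
  non-retryable failure    counted as error              counted, retryable=False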
Parent: b746f2e

2 files changed: 160 additions & 0 deletions


packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 47 additions & 0 deletions
@@ -47,6 +47,13 @@
 DEFAULT_TASK_POOL_SIZE: int = 256
 LLM_WAIT_POOL_MULTIPLIER: int = 2
 
+# Degraded-provider WARN: emit at most one warning per interval when the
+# rolling fraction of retryable errors exceeds the threshold. Distinct from
+# the early-shutdown gate (which fires on non-retryable errors).
+DEGRADED_WARN_RATE: float = 0.5
+DEGRADED_WARN_WINDOW: int = 20
+DEGRADED_WARN_INTERVAL_S: float = 60.0
+
 _RETRYABLE_MODEL_ERRORS = (
     ModelRateLimitError,
     ModelTimeoutError,
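Under the defaults, the WARN can only fire once the window holds 20 outcomes, and it then needs at least half of them to be retryable. A quick check of that arithmetic (illustrative only, not part of the diff):

import math

rate, window = 0.5, 20  # DEGRADED_WARN_RATE, DEGRADED_WARN_WINDOW
# Smallest retryable count in a full window satisfying sum / window >= rate:
assert math.ceil(rate * window) == 10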
@@ -105,6 +112,9 @@ def __init__(
         shutdown_error_rate: float = 0.5,
         shutdown_error_window: int = 10,
         disable_early_shutdown: bool = False,
+        degraded_warn_rate: float = DEGRADED_WARN_RATE,
+        degraded_warn_window: int = DEGRADED_WARN_WINDOW,
+        degraded_warn_interval_s: float = DEGRADED_WARN_INTERVAL_S,
         trace: bool = False,
         num_records: int = 0,
         buffer_size: int = 0,
@@ -177,6 +187,14 @@ def __init__(
         self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window)
         self._all_rgs_admitted = False
 
+        # Degraded-provider WARN: separate window tracking retryable-vs-not for
+        # every outcome (success or failure), throttled to one log per interval.
+        self._degraded_warn_rate = degraded_warn_rate
+        self._degraded_warn_window = degraded_warn_window
+        self._degraded_warn_interval_s = degraded_warn_interval_s
+        self._recent_retryable: deque[bool] = deque(maxlen=degraded_warn_window)
+        self._last_degraded_warn_at: float = 0.0
+
         # Pre-compute row-group sizes for O(1) lookup
         self._rg_size_map: dict[int, int] = dict(row_groups)
 
@@ -606,6 +624,33 @@ def _check_error_rate(self, *, success: bool) -> None:
         if errors / self._shutdown_error_window >= self._shutdown_error_rate:
             self._early_shutdown = True
 
+    def _record_retryable_outcome(self, *, retryable: bool) -> None:
+        """Track retryable-error rate and emit a throttled WARN under provider degradation.
+
+        Distinct from ``_check_error_rate``: every outcome (success or failure)
+        feeds this window so the rate reflects the provider's overall health, not
+        just the error mix. Only retryable errors (rate-limit, timeout, 5xx,
+        connection) count toward the rate; non-retryable failures register as 0.
+        """
+        if self._degraded_warn_window <= 0:
+            return
+        self._recent_retryable.append(retryable)
+        if len(self._recent_retryable) < self._degraded_warn_window:
+            return
+        rate = sum(self._recent_retryable) / self._degraded_warn_window
+        if rate < self._degraded_warn_rate:
+            return
+        now = time.monotonic()
+        if now - self._last_degraded_warn_at < self._degraded_warn_interval_s:
+            return
+        self._last_degraded_warn_at = now
+        pct = int(round(rate * 100))
+        logger.warning(
+            f"Provider showing degraded performance: {pct}% of last {self._degraded_warn_window} "
+            "task outcomes were retryable errors (rate-limit, timeout, 5xx, connection). "
+            "Run may take longer than expected; salvage will retry these."
+        )
+
     async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
         """Dispatch from_scratch tasks for a row group."""
         self._rg_states[rg_id].seeds_dispatched = True
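A minimal, self-contained sketch of the same windowing-plus-throttle logic as _record_retryable_outcome above (class name and fake clock are hypothetical; the sketch seeds the last-warn timestamp to -inf for a deterministic demo, where the real method uses time.monotonic() and starts at 0.0):

from collections import deque

class DegradedWarnSketch:
    """Illustration only: rolling retryable-rate window with a throttled warn."""

    def __init__(self, rate: float, window: int, interval_s: float, clock) -> None:
        self._rate = rate
        self._window = window
        self._interval_s = interval_s
        self._clock = clock                    # injected for a deterministic demo
        self._recent: deque[bool] = deque(maxlen=window)
        self._last_warn_at = float("-inf")     # real code starts at 0.0
        self.warn_count = 0

    def record(self, *, retryable: bool) -> None:
        self._recent.append(retryable)
        if len(self._recent) < self._window:   # window not full yet: stay silent
            return
        if sum(self._recent) / self._window < self._rate:
            return                             # healthy enough
        now = self._clock()
        if now - self._last_warn_at < self._interval_s:
            return                             # degraded, but throttled
        self._last_warn_at = now
        self.warn_count += 1

ticks = iter(range(100))                       # fake monotonic clock: 0, 1, 2, ...
sketch = DegradedWarnSketch(rate=0.5, window=4, interval_s=3600.0, clock=lambda: next(ticks))
for outcome in [True, True, False, True, True, True]:  # mostly retryable errors
    sketch.record(retryable=outcome)
assert sketch.warn_count == 1                  # threshold crossed 3x, throttled to one warn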
@@ -730,6 +775,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
                 self._tracker.mark_cell_complete(col, task.row_group, task.row_index)
 
             self._check_error_rate(success=True)
+            self._record_retryable_outcome(retryable=False)
             if self._reporter:
                 if cell_skipped:
                     self._reporter.record_skipped(task.column)
@@ -746,6 +792,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
             # and would otherwise trip the gate even when salvage could recover.
             if not retryable:
                 self._check_error_rate(success=False)
+            self._record_retryable_outcome(retryable=retryable)
             if not retryable and self._reporter:
                 self._reporter.record_failure(task.column)
             if self._trace and trace:
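Rendered with the default window, the new log message reads along these lines (the 75% figure is illustrative):

Provider showing degraded performance: 75% of last 20 task outcomes were retryable errors (rate-limit, timeout, 5xx, connection). Run may take longer than expected; salvage will retry these.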

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 113 additions & 0 deletions
@@ -919,6 +919,119 @@ async def test_retryable_errors_do_not_trigger_early_shutdown(
     assert tracker.is_row_group_complete(0, 10, ["seed", "col"])
 
 
+@pytest.mark.asyncio(loop_scope="session")
+async def test_degraded_provider_warn_fires_above_threshold(caplog: pytest.LogCaptureFixture) -> None:
+    """When >= threshold of recent outcomes are retryable errors, a WARN log fires."""
+    provider = _mock_provider()
+    configs = [
+        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
+        LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
+    ]
+    strategies = {
+        "seed": GenerationStrategy.FULL_COLUMN,
+        "col": GenerationStrategy.CELL_BY_CELL,
+    }
+    # 6 retryable failures across 10 cells + their successful retries → ~6/16 retryable.
+    # Set window to 8 and threshold to 0.5 so the WARN can fire.
+    generators = {
+        "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
+        "col": MockRetryableErrorGenerator(
+            config=_expr_config("col"),
+            resource_provider=provider,
+            error_factory=lambda: ModelTimeoutError("read timeout"),
+            retryable_failures=6,
+        ),
+    }
+
+    graph = ExecutionGraph.create(configs, strategies)
+    row_groups = [(0, 10)]
+    tracker = CompletionTracker.with_graph(graph, row_groups)
+
+    storage = MagicMock()
+    storage.dataset_name = "test"
+    storage.get_file_paths.return_value = {}
+    storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
+    storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
+    buffer_mgr = RowGroupBufferManager(storage)
+
+    scheduler = AsyncTaskScheduler(
+        generators=generators,
+        graph=graph,
+        tracker=tracker,
+        row_groups=row_groups,
+        buffer_manager=buffer_mgr,
+        degraded_warn_rate=0.5,
+        degraded_warn_window=8,
+        degraded_warn_interval_s=0.0,
+    )
+    with caplog.at_level("WARNING"):
+        await scheduler.run()
+
+    degraded_msgs = [r for r in caplog.records if "degraded performance" in r.getMessage()]
+    assert degraded_msgs, "expected a 'degraded performance' WARN to be emitted"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_degraded_provider_warn_throttled(caplog: pytest.LogCaptureFixture) -> None:
+    """Successive degraded windows within the throttle interval emit only one WARN."""
+    provider = _mock_provider()
+    configs = [
+        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
+        LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
+    ]
+    strategies = {
+        "seed": GenerationStrategy.FULL_COLUMN,
+        "col": GenerationStrategy.CELL_BY_CELL,
+    }
+    generators = {
+        "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
+        "col": MockRetryableErrorGenerator(
+            config=_expr_config("col"),
+            resource_provider=provider,
+            error_factory=lambda: ModelTimeoutError("read timeout"),
+            retryable_failures=8,
+        ),
+    }
+
+    graph = ExecutionGraph.create(configs, strategies)
+    row_groups = [(0, 12)]
+    tracker = CompletionTracker.with_graph(graph, row_groups)
+
+    storage = MagicMock()
+    storage.dataset_name = "test"
+    storage.get_file_paths.return_value = {}
+    storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
+    storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
+    buffer_mgr = RowGroupBufferManager(storage)
+
+    scheduler = AsyncTaskScheduler(
+        generators=generators,
+        graph=graph,
+        tracker=tracker,
+        row_groups=row_groups,
+        buffer_manager=buffer_mgr,
+        degraded_warn_rate=0.5,
+        degraded_warn_window=4,
+        degraded_warn_interval_s=3600.0,
+    )
+    with caplog.at_level("WARNING"):
+        await scheduler.run()
+
+    degraded_msgs = [r for r in caplog.records if "degraded performance" in r.getMessage()]
+    assert len(degraded_msgs) == 1, f"expected exactly one throttled WARN, got {len(degraded_msgs)}"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_degraded_provider_warn_silent_under_threshold(caplog: pytest.LogCaptureFixture) -> None:
+    """Healthy runs (no errors) never emit the degraded-provider WARN."""
+    scheduler, _tracker = _build_simple_pipeline(num_records=5)
+    with caplog.at_level("WARNING"):
+        await scheduler.run()
+
+    degraded_msgs = [r for r in caplog.records if "degraded performance" in r.getMessage()]
+    assert not degraded_msgs
+
+
 @pytest.mark.asyncio(loop_scope="session")
 async def test_scheduler_on_before_checkpoint_callback() -> None:
     """on_before_checkpoint is called before each row group is checkpointed."""
