diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index 14318791f..87be65548 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -15,6 +15,7 @@ from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy from data_designer.engine.column_generators.generators.base import SYNC_BRIDGE_TIMEOUT, ColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError +from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError from data_designer.logging import LOG_INDENT if TYPE_CHECKING: @@ -68,8 +69,13 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: return future.result(timeout=SYNC_BRIDGE_TIMEOUT) except concurrent.futures.TimeoutError as exc: future.cancel() - logger.warning("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT) - raise TimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc + # Demoted to debug: the raised ModelTimeoutError already surfaces + # the timeout at the scheduler with full context, and the throttled + # degraded-provider WARN is the user-facing signal under sustained + # bridge timeouts. Per-event WARN was noise on top of those. + logger.debug("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT) + # Raise as ModelTimeoutError so the scheduler classifies it retryable. + raise ModelTimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc def __getattr__(self, name: str) -> Any: return getattr(object.__getattribute__(self, "_facade"), name) @@ -147,6 +153,10 @@ async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | li result = await self._ainvoke_generator_function(data) except CustomColumnGenerationError: raise + except RETRYABLE_MODEL_ERRORS: + # Preserve retryability so the scheduler can salvage these + # instead of counting them toward the early-shutdown gate. + raise except Exception as e: logger.warning( f"⚠️ Custom generator function {self.config.generator_function.__name__!r} " @@ -193,6 +203,10 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. result = self._invoke_generator_function(data) except CustomColumnGenerationError: raise + except RETRYABLE_MODEL_ERRORS: + # Preserve retryability so the scheduler can salvage these + # instead of counting them toward the early-shutdown gate. 
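# What the pass-through above preserves, end to end (editor's sketch; `summarize`
# and the "gpt" alias are hypothetical names, and the signature is abbreviated):
# a retryable model error raised by the model call inside a user generator now
# reaches the scheduler unchanged instead of arriving wrapped as
# CustomColumnGenerationError.

@custom_column_generator(model_aliases=["gpt"])
def summarize(row: dict, models: dict) -> dict:
    # May raise ModelTimeoutError under provider degradation; the wrapper
    # re-raises it as-is, so the scheduler's _is_retryable() sees the real type.
    row["summary"], _ = models["gpt"].generate(row["text"], parser=str)
    return row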
+ raise except Exception as e: if not is_dataframe: logger.warning( diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index c62ff60d9..2120589b8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -30,12 +30,7 @@ ) from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace -from data_designer.engine.models.errors import ( - ModelAPIConnectionError, - ModelInternalServerError, - ModelRateLimitError, - ModelTimeoutError, -) +from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS if TYPE_CHECKING: from data_designer.engine.column_generators.generators.base import ColumnGenerator @@ -47,12 +42,13 @@ DEFAULT_TASK_POOL_SIZE: int = 256 LLM_WAIT_POOL_MULTIPLIER: int = 2 -_RETRYABLE_MODEL_ERRORS = ( - ModelRateLimitError, - ModelTimeoutError, - ModelInternalServerError, - ModelAPIConnectionError, -) +# Degraded-provider WARN: emit at most one warning per interval when the +# rolling fraction of retryable errors exceeds the threshold. Distinct from +# the early-shutdown gate (which fires on non-retryable errors). +# TODO: thread these through RunConfig so users can tune them per run. +DEGRADED_WARN_RATE: float = 0.5 +DEGRADED_WARN_WINDOW: int = 20 +DEGRADED_WARN_INTERVAL_S: float = 60.0 class TrackingSemaphore(asyncio.Semaphore): @@ -105,6 +101,9 @@ def __init__( shutdown_error_rate: float = 0.5, shutdown_error_window: int = 10, disable_early_shutdown: bool = False, + degraded_warn_rate: float = DEGRADED_WARN_RATE, + degraded_warn_window: int = DEGRADED_WARN_WINDOW, + degraded_warn_interval_s: float = DEGRADED_WARN_INTERVAL_S, trace: bool = False, num_records: int = 0, buffer_size: int = 0, @@ -177,6 +176,22 @@ def __init__( self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window) self._all_rgs_admitted = False + # Degraded-provider WARN: separate window tracking retryable-vs-not for + # every outcome (success or failure), throttled to one log per interval. + self._degraded_warn_rate = degraded_warn_rate + self._degraded_warn_window = degraded_warn_window + self._degraded_warn_interval_s = degraded_warn_interval_s + self._recent_retryable: deque[bool] = deque(maxlen=degraded_warn_window) + # Initialize to -inf so the first WARN is always emitted regardless of + # the monotonic clock's absolute value (which can be near-zero on freshly + # booted CI runners). + self._last_degraded_warn_at: float = float("-inf") + + # Row groups that were partially salvaged after early shutdown + # (i.e., some rows complete, some incomplete-then-dropped). Surfaced + # via the partial_row_groups property as a structured signal. 
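# Consumer-side sketch of the new structured signal (assumes `scheduler` is an
# AsyncTaskScheduler whose run() has finished; logger setup omitted):

if scheduler.early_shutdown:
    for rg_id in scheduler.partial_row_groups:
        # Each listed id checkpointed some rows and dropped the rest.
        logger.info("row group %s was partially salvaged", rg_id)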
+ self._partial_row_groups: list[int] = [] + # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) @@ -221,6 +236,20 @@ def _setup_async_progress_reporter( def active_worker_count(self) -> int: return sum(1 for t in self._worker_tasks if not t.done()) + @property + def early_shutdown(self) -> bool: + """True if the run terminated via the early-shutdown gate.""" + return self._early_shutdown + + @property + def partial_row_groups(self) -> tuple[int, ...]: + """Row group ids that were partially salvaged after early shutdown. + + Empty unless ``early_shutdown`` is True. Each id had some rows + complete and the rest dropped before checkpointing. + """ + return tuple(self._partial_row_groups) + def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: """Create a tracked worker task that auto-removes itself on completion.""" task = asyncio.create_task(coro) @@ -270,13 +299,9 @@ async def run(self) -> None: # Launch admission as a background task so it interleaves with dispatch. admission_task = asyncio.create_task(self._admit_row_groups()) - dispatch_error: BaseException | None = None try: # Main dispatch loop await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns) - except BaseException as exc: - dispatch_error = exc - raise finally: # Always cancel admission + drain in-flight workers, regardless # of how the dispatch loop exited (normal, early shutdown, @@ -286,11 +311,18 @@ async def run(self) -> None: with contextlib.suppress(asyncio.CancelledError): await admission_task await asyncio.shield(self._cancel_workers()) - + # Salvage partially-complete row groups left over from early + # shutdown. Must run AFTER _cancel_workers - in-flight tasks + # could otherwise write into a buffer that's being finalized. + if self._early_shutdown and self._rg_states: + self._finalize_after_shutdown(all_columns) + + # Reached only on the clean-exit path; an exception in the + # dispatch loop or the finally block propagates and skips this. if self._reporter: self._reporter.log_final() - if self._rg_states and dispatch_error is None: + if self._rg_states: incomplete = list(self._rg_states) logger.error( f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " @@ -306,7 +338,7 @@ async def _main_dispatch_loop( """Core dispatch loop extracted from ``run()``.""" while True: if self._early_shutdown: - logger.warning("Early shutdown triggered - error rate exceeded threshold") + logger.warning("Early shutdown triggered - non-retryable error rate exceeded threshold") if self._deferred: await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) @@ -534,6 +566,44 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: checkpointed = {rg_id for rg_id, _ in completed} self._deferred = [t for t in self._deferred if t.row_group not in checkpointed] + def _finalize_after_shutdown(self, all_columns: list[str]) -> None: + """Salvage row groups left in flight when early shutdown fired. + + For each remaining row group, drop rows that aren't fully complete + (and weren't already dropped); after that, ``is_row_group_complete`` + is true by construction over the surviving rows, so delegating to + ``_checkpoint_completed_row_groups`` writes survivors and frees + zero-survivor groups via the buffer manager's existing logic. 
+ + Note on processors: ``_checkpoint_completed_row_groups`` calls + ``on_before_checkpoint`` (post-batch) but never ``on_seeds_complete`` + (pre-batch). If the gate fires before seeds completed for a row + group, that row group's pre-batch processor never ran. Survivors + are checkpointed without it. This is the existing contract for + partial-row-group salvage. + """ + for rg_id in list(self._rg_states.keys()): + rg_size = self._rg_states[rg_id].size + had_incomplete = False + for ri in range(rg_size): + if self._tracker.is_dropped(rg_id, ri): + continue + if all( + self._tracker.is_complete(SliceRef(column=col, row_group=rg_id, row_index=ri)) + for col in all_columns + ): + continue + had_incomplete = True + self._drop_row(rg_id, ri) + if had_incomplete: + survivors = sum(1 for ri in range(rg_size) if not self._tracker.is_dropped(rg_id, ri)) + if survivors > 0: + self._partial_row_groups.append(rg_id) + logger.warning(f"Row group {rg_id}: salvaging {survivors} of {rg_size} rows after early shutdown.") + else: + logger.warning(f"Row group {rg_id}: 0 of {rg_size} rows survived early shutdown - skipping write.") + self._checkpoint_completed_row_groups(all_columns) + def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: """Run pre-batch callbacks for row groups whose seeds just completed.""" for rg_id, state in list(self._rg_states.items()): @@ -606,6 +676,35 @@ def _check_error_rate(self, *, success: bool) -> None: if errors / self._shutdown_error_window >= self._shutdown_error_rate: self._early_shutdown = True + def _record_retryable_outcome(self, *, retryable: bool) -> None: + """Track retryable-error rate and emit a throttled WARN under provider degradation. + + Distinct from ``_check_error_rate``: every LLM-bound task outcome (success + or failure) feeds this window so the rate reflects the provider's overall + health, not just the error mix. The call site filters on ``is_llm`` so + non-LLM tasks (samplers, expressions, non-LLM customs) don't dilute the + rate. Only retryable errors (rate-limit, timeout, 5xx, connection) count + toward the rate; non-retryable failures register as 0. + """ + if self._degraded_warn_window <= 0: + return + self._recent_retryable.append(retryable) + if len(self._recent_retryable) < self._degraded_warn_window: + return + rate = sum(self._recent_retryable) / self._degraded_warn_window + if rate < self._degraded_warn_rate: + return + now = time.monotonic() + if now - self._last_degraded_warn_at < self._degraded_warn_interval_s: + return + self._last_degraded_warn_at = now + pct = int(round(rate * 100)) + logger.warning( + f"Provider showing degraded performance: {pct}% of last {self._degraded_warn_window} " + "task outcomes were retryable errors (rate-limit, timeout, 5xx, connection). " + "Run may take longer than expected; salvage will retry these." + ) + async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" self._rg_states[rg_id].seeds_dispatched = True @@ -730,6 +829,12 @@ async def _execute_task_inner_impl(self, task: Task) -> None: self._tracker.mark_cell_complete(col, task.row_group, task.row_index) self._check_error_rate(success=True) + # The degraded-provider WARN is provider-scoped: only feed the + # window from LLM-bound tasks so a healthy non-model task mix + # (samplers, expressions, non-LLM customs) doesn't dilute the + # rate and silence the WARN under genuine provider stress. 
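# The window mechanics of _record_retryable_outcome (defined above) in
# isolation - a self-contained sketch with the module defaults inlined:

from collections import deque

window: deque[bool] = deque(maxlen=20)       # DEGRADED_WARN_WINDOW
for retryable in [True] * 12 + [False] * 8:  # 12 retryable errors, 8 other outcomes
    window.append(retryable)
assert len(window) == window.maxlen          # rate is only checked on a full window
assert sum(window) / window.maxlen >= 0.5    # DEGRADED_WARN_RATE -> WARN eligible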
+ if is_llm: + self._record_retryable_outcome(retryable=False) if self._reporter: if cell_skipped: self._reporter.record_skipped(task.column) @@ -739,9 +844,15 @@ async def _execute_task_inner_impl(self, task: Task) -> None: trace.status = "ok" except Exception as exc: - if not isinstance(exc, ModelRateLimitError): - self._check_error_rate(success=False) retryable = self._is_retryable(exc) + # Only non-retryable errors (auth, schema, code bugs) count toward + # the early-shutdown gate. Retryable errors (rate-limit, timeout, + # transient 5xx, connection blips) cluster under provider degradation + # and would otherwise trip the gate even when salvage could recover. + if not retryable: + self._check_error_rate(success=False) + if is_llm: + self._record_retryable_outcome(retryable=retryable) if not retryable and self._reporter: self._reporter.record_failure(task.column) if self._trace and trace: @@ -930,7 +1041,7 @@ def get_semaphore_permits(self) -> tuple[int, int]: @staticmethod def _is_retryable(exc: Exception) -> bool: """Classify whether an exception is retryable.""" - return isinstance(exc, _RETRYABLE_MODEL_ERRORS) + return isinstance(exc, RETRYABLE_MODEL_ERRORS) def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index c76be0980..96470977c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -113,6 +113,16 @@ def __init__( self._registry = registry or DataDesignerRegistry() self._graph: ExecutionGraph | None = None self._use_async: bool = DATA_DESIGNER_ASYNC_ENGINE + # Structured signal: set by _build_async if the scheduler hit early shutdown. + # Stays at defaults for sync-engine and successful async runs. Reset at + # the start of each public run path so reused builder instances don't + # leak state across runs. + self._early_shutdown: bool = False + self._partial_row_groups: tuple[int, ...] = () + # Number of records actually written by the most recent async run. + # ``-1`` means "no async run has executed yet" so callers can + # distinguish "0 records produced" from "never ran". 
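# Tri-state read of the sentinel described above (sketch; `builder` stands for
# any DatasetBuilder instance):

if builder.actual_num_records == -1:
    status = "no async run has executed yet"
elif builder.actual_num_records == 0:
    status = "ran, but zero records survived"
else:
    status = f"ran and wrote {builder.actual_num_records} record(s)"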
+ self._actual_num_records: int = -1 self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) self._column_configs = compile_dataset_builder_column_configs(self._data_designer_config) @@ -135,6 +145,21 @@ def processors(self) -> tuple[Processor, ...]: def task_traces(self) -> list[TaskTrace]: return self._task_traces + @property + def early_shutdown(self) -> bool: + """True if the most recent async run terminated via the early-shutdown gate.""" + return self._early_shutdown + + @property + def partial_row_groups(self) -> tuple[int, ...]: + """Row group ids that were partially salvaged after early shutdown (most recent run).""" + return self._partial_row_groups + + @property + def actual_num_records(self) -> int: + """Records actually written by the most recent async run (-1 if no run yet).""" + return self._actual_num_records + def set_processor_runner(self, processors: list[Processor]) -> None: """Replace the processor runner with a new one using the given processors.""" self._processor_runner = ProcessorRunner( @@ -179,6 +204,7 @@ def build( Returns: Path to the generated dataset directory. """ + self._reset_run_state() self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() self._write_builder_config() @@ -215,6 +241,7 @@ def build( return self.artifact_storage.final_dataset_path def build_preview(self, *, num_records: int) -> pd.DataFrame: + self._reset_run_state() self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() @@ -239,6 +266,13 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: return dataset + def _reset_run_state(self) -> None: + """Clear per-run signals so reused builder instances don't leak state across runs.""" + self._early_shutdown = False + self._partial_row_groups = () + self._actual_num_records = -1 + self._task_traces = [] + def _build_async_preview(self, generators: list[ColumnGenerator], num_records: int) -> pd.DataFrame: """Async preview path - single row group, no disk writes, returns in-memory DataFrame.""" logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue preview") @@ -256,9 +290,13 @@ def _build_async_preview(self, generators: list[ColumnGenerator], num_records: i loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) - future.result() - - self._task_traces = scheduler.traces + try: + future.result() + finally: + self._task_traces = scheduler.traces + self._early_shutdown = scheduler.early_shutdown + self._partial_row_groups = scheduler.partial_row_groups + self._actual_num_records = buffer_manager.actual_num_records if not buffer_manager.has_row_group(0): return lazy.pd.DataFrame() @@ -320,12 +358,19 @@ def on_complete(final_path: Path | str | None) -> None: group_id = uuid.uuid4().hex pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() - # Run on background event loop + # Run on background event loop. Capture scheduler state in `finally` + # so the structured signal is preserved even if `scheduler.run()` + # raises during the salvage path - otherwise callers see a generic + # error and lose the early-shutdown context. 
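# Why the reset matters (sketch; the call shape is illustrative, not the exact
# build() signature): without _reset_run_state, a reused builder would keep
# reporting the first run's early_shutdown even after a clean second run.

first = builder.build()            # suppose this run trips the gate
assert builder.early_shutdown
second = builder.build()           # clean rerun on the same instance
assert not builder.early_shutdown  # cleared at entry, repopulated per run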
        loop = ensure_async_engine_loop()
        future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop)
-        future.result()
-
-        self._task_traces = scheduler.traces
+        try:
+            future.result()
+        finally:
+            self._task_traces = scheduler.traces
+            self._early_shutdown = scheduler.early_shutdown
+            self._partial_row_groups = scheduler.partial_row_groups
+            self._actual_num_records = buffer_manager.actual_num_records
 
         # Emit telemetry
         try:
@@ -338,13 +383,21 @@
         buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size)
 
         # Surface partial completion
-        actual = buffer_manager.actual_num_records
+        actual = self._actual_num_records
         if actual < num_records:
             pct = actual / num_records * 100 if num_records > 0 else 0
-            logger.warning(
-                f"⚠️ Generated {actual} of {num_records} requested records ({pct:.0f}%). "
-                "The dataset may be incomplete due to errors or early shutdown."
-            )
+            base = f"⚠️ Generated {actual} of {num_records} requested records ({pct:.0f}%). "
+            if scheduler.early_shutdown:
+                partial = scheduler.partial_row_groups
+                detail = (
+                    f"Early shutdown was triggered (non-retryable error rate exceeded threshold); "
+                    f"{len(partial)} row group(s) salvaged with partial rows."
+                    if partial
+                    else "Early shutdown was triggered (non-retryable error rate exceeded threshold)."
+                )
+                logger.warning(base + detail)
+            else:
+                logger.warning(base + "The dataset may be incomplete due to dropped rows.")
 
     def _prepare_async_run(
         self,
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py
index a9a16f60e..e8f720c0c 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py
@@ -310,6 +310,15 @@ def release_failure(
         with self._lock:
             state = self._get_or_create_domain(provider_name, model_id, domain)
             state.in_flight = max(0, state.in_flight - 1)
+            # Non-rate-limit failure breaks the 429 cascade: a sequence like
+            # 429 → 500 → 429 should treat the second 429 as the start of a
+            # new cascade. But only after the prior burst has fully drained
+            # (in_flight == 0) - otherwise mixed responses from a single
+            # in-flight wave (429 → 500 → 429 with concurrent slots) would
+            # double-reduce the limit even though the provider hasn't
+            # recovered between the two 429s.
+            if state.in_flight == 0:
+                state.consecutive_429s = 0
 
     # -------------------------------------------------------------------
     # Sync / async wrappers
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
index 95469f8bd..dc054cff2 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
@@ -131,6 +131,17 @@
 class ImageGenerationError(DataDesignerError): ...
 
 
+# Errors that the async scheduler defers to salvage instead of failing the run.
+# Callers that wrap arbitrary exceptions (e.g. CustomColumnGenerator) should
+# re-raise these unchanged so retryability is preserved through the wrap.
+RETRYABLE_MODEL_ERRORS: tuple[type[Exception], ...] = (
+    ModelRateLimitError,
+    ModelTimeoutError,
+    ModelInternalServerError,
+    ModelAPIConnectionError,
+)
+
+
 class FormattedLLMErrorMessage(BaseModel):
     cause: str
     solution: str
diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py
index 2d73fc8dd..9aa0afd5e 100644
--- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py
+++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py
@@ -5,8 +5,10 @@
 from __future__ import annotations
 
+import asyncio
+import threading
 from typing import TYPE_CHECKING, Any
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 from pydantic import BaseModel, ValidationError
@@ -18,8 +20,10 @@
 from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
 from data_designer.config.custom_column import custom_column_generator
-from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
+from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator, _AsyncBridgedModelFacade
 from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
+from data_designer.engine.models.clients.errors import SyncClientUnavailableError
+from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError
 from data_designer.engine.resources.resource_provider import ResourceProvider
 
 
@@ -350,6 +354,37 @@ def failing_generator(row: dict) -> dict:
     assert "something broke" in caplog.text
 
 
+@pytest.mark.parametrize("exc_cls", RETRYABLE_MODEL_ERRORS, ids=lambda c: c.__name__)
+def test_retryable_model_errors_pass_through_sync_wrap(exc_cls: type[Exception]) -> None:
+    """Retryable model errors raised inside a sync generator must NOT be wrapped.
+
+    Without this, the scheduler classifies the wrapped error as non-retryable and
+    counts it toward the early-shutdown gate (regression seen in #575 follow-up).
+ """ + + @custom_column_generator() + def raising_gen(row: dict) -> dict: + raise exc_cls("boom") + + generator = _create_test_generator(name="result", generator_function=raising_gen) + with pytest.raises(exc_cls): + generator.generate({"input": 1}) + + +@pytest.mark.parametrize("exc_cls", RETRYABLE_MODEL_ERRORS, ids=lambda c: c.__name__) +@pytest.mark.asyncio +async def test_retryable_model_errors_pass_through_async_wrap(exc_cls: type[Exception]) -> None: + """Retryable errors raised inside an async user generator must propagate unchanged.""" + + @custom_column_generator() + async def raising_gen(row: dict) -> dict: + raise exc_cls("boom") + + generator = _create_test_generator(name="result", generator_function=raising_gen) + with pytest.raises(exc_cls): + await generator.agenerate({"input": 1}) + + def test_undeclared_columns_removed_with_warning(caplog: pytest.LogCaptureFixture) -> None: """Test that undeclared columns are removed with a warning.""" import logging @@ -469,107 +504,125 @@ def df_func(df: pd.DataFrame) -> pd.DataFrame: gen.generate({"input": 1}) -# Async model bridge tests +# Async model bridge tests for _AsyncBridgedModelFacade + + +def test_async_bridge_proxy_transparent_in_sync_mode(stub_resource_provider, stub_model_facade) -> None: + """Proxy passes through generate(), forwards attributes; _build_models_dict returns raw facades.""" + @custom_column_generator(required_columns=["input"], model_aliases=["test-model"]) + def gen_with_model(row: dict, generator_params: SampleParams, models: dict) -> dict: + row["result"] = "ok" + return row -class TestAsyncBridgedModelFacade: - """Tests for _AsyncBridgedModelFacade proxy used by custom columns with model access.""" + generator = _create_test_generator( + name="result", + generator_function=gen_with_model, + generator_params=SampleParams(), + resource_provider=stub_resource_provider, + ) - def test_proxy_transparent_in_sync_mode(self, stub_resource_provider, stub_model_facade) -> None: - """Proxy passes through generate(), forwards attributes; _build_models_dict returns raw facades.""" - from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade + # _build_models_dict returns raw facades (wrapping happens at the call site) + models = generator._build_models_dict() + assert not isinstance(models["test-model"], _AsyncBridgedModelFacade) - @custom_column_generator(required_columns=["input"], model_aliases=["test-model"]) - def gen_with_model(row: dict, generator_params: SampleParams, models: dict) -> dict: - row["result"] = "ok" - return row + # Proxy itself passes through generate() and forwards attributes + proxy = _AsyncBridgedModelFacade(stub_model_facade) + result, _ = proxy.generate("test", parser=str) + assert result == "Generated summary text" + stub_model_facade.generate.assert_called_once_with("test", parser=str) + assert proxy.model_alias == "test_model" - generator = _create_test_generator( - name="result", - generator_function=gen_with_model, - generator_params=SampleParams(), - resource_provider=stub_resource_provider, - ) - # _build_models_dict returns raw facades (wrapping happens at the call site) - models = generator._build_models_dict() - assert not isinstance(models["test-model"], _AsyncBridgedModelFacade) - - # Proxy itself passes through generate() and forwards attributes - proxy = _AsyncBridgedModelFacade(stub_model_facade) - result, _ = proxy.generate("test", parser=str) - assert result == "Generated summary text" - 
stub_model_facade.generate.assert_called_once_with("test", parser=str) - assert proxy.model_alias == "test_model" - - def test_bridges_to_agenerate_on_sync_client_error(self) -> None: - """When sync generate() fails with an async/sync error, falls back to agenerate().""" - import asyncio - import threading - from unittest.mock import patch - - from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade - from data_designer.engine.models.clients.errors import SyncClientUnavailableError - - facade = Mock() - facade.generate.side_effect = SyncClientUnavailableError( - "Sync methods are not available on an async-mode HttpModelClient." - ) +def test_async_bridge_falls_back_to_agenerate_on_sync_client_error() -> None: + """When sync generate() fails with an async/sync error, falls back to agenerate().""" + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." + ) - async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple: - return ("async_result", list(args), kwargs) + async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple: + return ("async_result", list(args), kwargs) + + facade.agenerate = fake_agenerate + proxy = _AsyncBridgedModelFacade(facade) + + engine_loop = asyncio.new_event_loop() + engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True) + engine_thread.start() + + try: + with patch( + "data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop", + return_value=engine_loop, + ): + result = proxy.generate("hello", parser=str) + assert result == ("async_result", ["hello"], {"parser": str}) + finally: + engine_loop.call_soon_threadsafe(engine_loop.stop) + engine_thread.join(timeout=5) + + +def test_async_bridge_non_client_mode_errors_propagate() -> None: + """Only SyncClientUnavailableError triggers bridging; other errors propagate.""" + # ValueError - different type entirely + facade = Mock() + facade.generate.side_effect = ValueError("invalid prompt format") + proxy = _AsyncBridgedModelFacade(facade) + with pytest.raises(ValueError, match="invalid prompt format"): + proxy.generate(prompt="hello") + + # RuntimeError - same base type as SyncClientUnavailableError, but not caught + facade = Mock() + facade.generate.side_effect = RuntimeError("connection timed out for async request") + proxy = _AsyncBridgedModelFacade(facade) + with pytest.raises(RuntimeError, match="connection timed out"): + proxy.generate(prompt="hello") + + +def test_async_bridge_timeout_raises_model_timeout_error() -> None: + """A bridge timeout must surface as ModelTimeoutError so the scheduler sees it as retryable.""" + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." 
+ ) - facade.agenerate = fake_agenerate - proxy = _AsyncBridgedModelFacade(facade) + async def hangs_forever(*args: Any, **kwargs: Any) -> tuple: + await asyncio.sleep(60) + return ("never", [], {}) - engine_loop = asyncio.new_event_loop() - engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True) - engine_thread.start() + facade.agenerate = hangs_forever + proxy = _AsyncBridgedModelFacade(facade) - try: - with patch( + engine_loop = asyncio.new_event_loop() + engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True) + engine_thread.start() + + try: + with ( + patch( "data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop", return_value=engine_loop, - ): - result = proxy.generate("hello", parser=str) - assert result == ("async_result", ["hello"], {"parser": str}) - finally: - engine_loop.call_soon_threadsafe(engine_loop.stop) - engine_thread.join(timeout=5) - - def test_non_client_mode_errors_propagate(self) -> None: - """Only SyncClientUnavailableError triggers bridging; other errors propagate.""" - from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade - - # ValueError - different type entirely - facade = Mock() - facade.generate.side_effect = ValueError("invalid prompt format") - proxy = _AsyncBridgedModelFacade(facade) - with pytest.raises(ValueError, match="invalid prompt format"): - proxy.generate(prompt="hello") - - # RuntimeError - same base type as SyncClientUnavailableError, but not caught - facade = Mock() - facade.generate.side_effect = RuntimeError("connection timed out for async request") - proxy = _AsyncBridgedModelFacade(facade) - with pytest.raises(RuntimeError, match="connection timed out"): - proxy.generate(prompt="hello") - - def test_deadlock_guard_on_event_loop(self) -> None: - """Raises a clear error instead of deadlocking when called from the event loop.""" - import asyncio - - from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade - from data_designer.engine.models.clients.errors import SyncClientUnavailableError - - facade = Mock() - facade.generate.side_effect = SyncClientUnavailableError( - "Sync methods are not available on an async-mode HttpModelClient." - ) - proxy = _AsyncBridgedModelFacade(facade) + ), + patch("data_designer.engine.column_generators.generators.custom.SYNC_BRIDGE_TIMEOUT", 0.05), + pytest.raises(ModelTimeoutError, match="bridge timed out"), + ): + proxy.generate("hello") + finally: + engine_loop.call_soon_threadsafe(engine_loop.stop) + engine_thread.join(timeout=5) + + +def test_async_bridge_deadlock_guard_on_event_loop() -> None: + """Raises a clear error instead of deadlocking when called from the event loop.""" + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." 
+ ) + proxy = _AsyncBridgedModelFacade(facade) - async def call_from_loop() -> None: - proxy.generate(prompt="hello") + async def call_from_loop() -> None: + proxy.generate(prompt="hello") - with pytest.raises(RuntimeError, match="Use 'await model.agenerate\\(\\)'"): - asyncio.run(call_from_loop()) + with pytest.raises(RuntimeError, match="Use 'await model.agenerate\\(\\)'"): + asyncio.run(call_from_loop()) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 0c6ec4e4d..fe536957c 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable from typing import Any from unittest.mock import MagicMock @@ -31,7 +32,12 @@ from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager -from data_designer.engine.models.errors import ModelInternalServerError, ModelRateLimitError +from data_designer.engine.models.errors import ( + RETRYABLE_MODEL_ERRORS, + ModelInternalServerError, + ModelRateLimitError, + ModelTimeoutError, +) from data_designer.engine.resources.resource_provider import ResourceProvider MODEL_ALIAS = "stub" @@ -167,6 +173,90 @@ def generate(self, data: dict) -> dict: return data +class MockSelectiveFailGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Cell generator with deterministic per-seed behavior. + + - Seeds in ``fail_on_seeds``: raise a non-retryable ``ValueError`` immediately. + - Seeds in ``slow_seeds``: block on ``asyncio.sleep`` so they remain + in-flight when the early-shutdown gate fires. + - All others: succeed. + + Cell-by-cell only — exercised through ``agenerate`` from the async scheduler. + """ + + def __init__( + self, + *args: Any, + fail_on_seeds: set[int] = frozenset(), + slow_seeds: set[int] = frozenset(), + slow_timeout_s: float = 5.0, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._fail = set(fail_on_seeds) + self._slow = set(slow_seeds) + self._slow_timeout_s = slow_timeout_s + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + async def agenerate(self, data: dict) -> dict: + seed = data.get("seed") + if seed in self._fail: + raise ValueError(f"non-retryable on seed={seed}") + if seed in self._slow: + await asyncio.sleep(self._slow_timeout_s) + data[self.config.name] = f"ok_{seed}" + return data + + def generate(self, data: dict) -> dict: + # Sync path: kept minimal because this mock is exercised exclusively + # through ``agenerate`` from the async scheduler. ``slow_seeds`` is + # intentionally not honored here — callers needing sync slow behavior + # should use a different fixture. + seed = data.get("seed") + if seed in self._fail: + raise ValueError(f"non-retryable on seed={seed}") + data[self.config.name] = f"ok_{seed}" + return data + + +class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Generator that raises a parametrizable retryable error then succeeds. 
+ + Declares ``is_llm_bound=True`` because it mimics model-call behavior; + the scheduler's degraded-provider WARN window only counts LLM-bound tasks. + """ + + def __init__( + self, + *args: Any, + error_factory: Callable[[], Exception], + retryable_failures: int = 0, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._error_factory = error_factory + self._retryable_failures = retryable_failures + self._calls = 0 + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + @property + def is_llm_bound(self) -> bool: + return True + + def generate(self, data: dict) -> dict: + self._calls += 1 + if self._calls <= self._retryable_failures: + raise self._error_factory() + data[self.config.name] = f"ok_{data.get('seed', '?')}" + return data + + # -- Helper to build graph + scheduler ---------------------------------------- @@ -219,6 +309,51 @@ def _build_simple_pipeline( return scheduler, tracker +def _make_storage() -> MagicMock: + """Standard mock storage for buffer-manager-backed scheduler tests.""" + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + return storage + + +def _seed_plus_cell_setup( + cell_generator: ColumnGenerator, + num_records: int, +) -> tuple[ + dict[str, ColumnGenerator], + ExecutionGraph, + list[tuple[int, int]], + CompletionTracker, + RowGroupBufferManager, + MagicMock, +]: + """Build the shared seed → LLM cell pipeline scaffolding (no scheduler yet). + + Used by early-shutdown / WARN tests that need a real ``buffer_manager`` + *before* constructing the scheduler (e.g. to wire a checkpoint callback + that closes over it). + """ + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN, "cell_out": GenerationStrategy.CELL_BY_CELL} + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": cell_generator, + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, num_records)] + tracker = CompletionTracker.with_graph(graph, row_groups) + storage = _make_storage() + buffer_manager = RowGroupBufferManager(storage) + return generators, graph, row_groups, tracker, buffer_manager, storage + + # -- Tests -------------------------------------------------------------------- @@ -689,6 +824,89 @@ async def test_scheduler_error_rate_shutdown() -> None: # Early shutdown: not all rows should be checkpointed (some row groups incomplete) assert buffer_mgr.actual_num_records < 10 + # No leftover unfinished row groups (finalize-after-shutdown drains them). + assert not scheduler._rg_states + + +@pytest.mark.asyncio(loop_scope="session") +async def test_partial_row_group_salvaged_after_early_shutdown() -> None: + """Mid-run shutdown drops incomplete rows and checkpoints survivors.""" + # 3 succeed (0,1,2), 3 fail non-retryable (5,6,7), 4 stay in-flight (3,4,8,9) + # until cancellation. Window=4, rate=0.5 → gate trips after ~3-5 outcomes. 
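# The gate arithmetic for this test in isolation (self-contained sketch of
# _check_error_rate's window math with this test's parameters inlined):

from collections import deque

outcomes: deque[bool] = deque(maxlen=4)   # shutdown_error_window=4
for ok in (True, True, False, False):     # 2 successes, then 2 non-retryable failures
    outcomes.append(ok)
assert outcomes.count(False) / 4 >= 0.5   # shutdown_error_rate=0.5 -> gate trips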
+ cell = MockSelectiveFailGenerator( + config=_expr_config("cell_out"), + resource_provider=_mock_provider(), + fail_on_seeds={5, 6, 7}, + slow_seeds={3, 4, 8, 9}, + ) + generators, graph, row_groups, tracker, buffer_mgr, _storage = _seed_plus_cell_setup(cell, num_records=10) + finalized: list[int] = [] + + def on_finalize(rg_id: int) -> None: + buffer_mgr.checkpoint_row_group(rg_id) + finalized.append(rg_id) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_finalize_row_group=on_finalize, + shutdown_error_rate=0.5, + shutdown_error_window=4, + ) + await scheduler.run() + + assert scheduler.early_shutdown + # Survivor count depends on event-loop dispatch ordering between fast/fail/slow + # seeds, so the assertion is bounded rather than exact: 3 fail → at least 3 + # dropped, so survivors ≤ 7; at least 1 success is needed for the gate to + # start counting. The point of the test is "salvage works", not exact counts. + assert 0 in finalized + assert scheduler.partial_row_groups == (0,) + assert 1 <= buffer_mgr.actual_num_records <= 7 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_zero_survivor_shutdown_does_not_raise() -> None: + """If every row is dropped at shutdown, the row group is freed without writing parquet. + + Also covers the healthy-run baseline: ``partial_row_groups`` stays empty + when no rows survived (all dropped, none salvaged). + """ + cell = MockSelectiveFailGenerator( + config=_expr_config("cell_out"), + resource_provider=_mock_provider(), + fail_on_seeds=set(range(5)), + ) + generators, graph, row_groups, tracker, buffer_mgr, storage = _seed_plus_cell_setup(cell, num_records=5) + finalized: list[int] = [] + + def on_finalize(rg_id: int) -> None: + buffer_mgr.checkpoint_row_group(rg_id) + finalized.append(rg_id) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_finalize_row_group=on_finalize, + shutdown_error_rate=0.5, + shutdown_error_window=2, + ) + # Must not raise (no FileNotFoundError, no DataDesignerGenerationError). + await scheduler.run() + + assert scheduler.early_shutdown + assert buffer_mgr.actual_num_records == 0 + # All rows dropped → checkpoint path frees buffer without writing; on_finalize + # is *not* called because every row was dropped before survivors could exist. + assert finalized == [] + assert scheduler.partial_row_groups == () + storage.write_batch_to_parquet_file.assert_not_called() @pytest.mark.asyncio(loop_scope="session") @@ -822,6 +1040,127 @@ async def test_rate_limit_errors_do_not_trigger_early_shutdown() -> None: assert tracker.is_row_group_complete(0, 10, ["seed", "col"]) +@pytest.mark.parametrize("exc_cls", RETRYABLE_MODEL_ERRORS, ids=lambda c: c.__name__) +@pytest.mark.asyncio(loop_scope="session") +async def test_retryable_errors_do_not_trigger_early_shutdown( + exc_cls: type[Exception], +) -> None: + """All retryable errors (rate-limit, timeout, 5xx, connection) bypass the early-shutdown gate. + + Regression test for #575: clustered ``ModelTimeoutError`` during provider degradation + used to trip the gate even though salvage could recover the rows. 
+ """ + cell = MockRetryableErrorGenerator( + config=_expr_config("cell_out"), + resource_provider=_mock_provider(), + error_factory=lambda: exc_cls("boom"), + retryable_failures=8, + ) + generators, graph, row_groups, tracker, buffer_mgr, _storage = _seed_plus_cell_setup(cell, num_records=10) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + shutdown_error_rate=0.5, + shutdown_error_window=10, + ) + await scheduler.run() + + assert not scheduler._early_shutdown + assert scheduler._recent_outcomes.count(False) == 0 + assert tracker.is_row_group_complete(0, 10, ["seed", "cell_out"]) + + +def _count_degraded_msgs(caplog: pytest.LogCaptureFixture) -> int: + return sum(1 for r in caplog.records if "degraded performance" in r.getMessage()) + + +@pytest.mark.parametrize( + "retryable_failures,num_records,window,interval_s,expected_count", + [ + # Above-threshold + zero throttle: at least one WARN should fire. + pytest.param(6, 10, 8, 0.0, "at_least_one", id="fires_above_threshold"), + # Above-threshold + 1h throttle: only one WARN despite sustained degradation. + pytest.param(8, 12, 4, 3600.0, 1, id="throttled_to_one"), + ], +) +@pytest.mark.asyncio(loop_scope="session") +async def test_degraded_provider_warn_emission( + caplog: pytest.LogCaptureFixture, + retryable_failures: int, + num_records: int, + window: int, + interval_s: float, + expected_count: int | str, +) -> None: + cell = MockRetryableErrorGenerator( + config=_expr_config("cell_out"), + resource_provider=_mock_provider(), + error_factory=lambda: ModelTimeoutError("read timeout"), + retryable_failures=retryable_failures, + ) + generators, graph, row_groups, tracker, buffer_mgr, _storage = _seed_plus_cell_setup(cell, num_records=num_records) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + degraded_warn_rate=0.5, + degraded_warn_window=window, + degraded_warn_interval_s=interval_s, + ) + with caplog.at_level("WARNING"): + await scheduler.run() + + n = _count_degraded_msgs(caplog) + if expected_count == "at_least_one": + assert n >= 1 + else: + assert n == expected_count + + +@pytest.mark.asyncio(loop_scope="session") +async def test_degraded_provider_warn_silent_under_threshold(caplog: pytest.LogCaptureFixture) -> None: + """Healthy runs (no errors) never emit the degraded-provider WARN.""" + scheduler, _tracker = _build_simple_pipeline(num_records=5) + with caplog.at_level("WARNING"): + await scheduler.run() + assert _count_degraded_msgs(caplog) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_degraded_provider_warn_only_counts_llm_tasks() -> None: + """The WARN window must ignore non-LLM task outcomes (samplers, expressions, etc). + + Without this, a healthy non-model column mix dilutes the retryable rate and + the WARN never fires under genuine provider stress. + """ + # Sampler-only graph: no LLM tasks → window must stay empty regardless of + # how many task outcomes feed in. 
+ configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=_mock_provider())} + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5)] + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_mgr = RowGroupBufferManager(_make_storage()) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + degraded_warn_rate=0.5, + degraded_warn_window=2, + degraded_warn_interval_s=0.0, + ) + await scheduler.run() + assert len(scheduler._recent_retryable) == 0 + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_on_before_checkpoint_callback() -> None: """on_before_checkpoint is called before each row group is checkpointed.""" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index dd72b8461..1e241a454 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -615,6 +615,8 @@ def test_build_async_preview_returns_empty_dataframe_when_row_group_is_already_f class StubScheduler: traces: list[object] = [] + early_shutdown: bool = False + partial_row_groups: tuple[int, ...] = () async def run(self) -> None: return None @@ -630,6 +632,7 @@ def mock_run_coroutine_threadsafe(coro, loop): scheduler = StubScheduler() buffer_manager = Mock() buffer_manager.has_row_group.return_value = False + buffer_manager.actual_num_records = 0 monkeypatch.setattr(builder, "_prepare_async_run", Mock(return_value=(scheduler, buffer_manager))) monkeypatch.setattr(builder_mod, "ensure_async_engine_loop", lambda: object(), raising=False) @@ -647,6 +650,26 @@ def mock_run_coroutine_threadsafe(coro, loop): buffer_manager.free_row_group.assert_not_called() +def test_reset_run_state_clears_per_run_signals(stub_resource_provider, stub_test_config_builder) -> None: + """``_reset_run_state`` must clear all per-run state so reused builders don't leak.""" + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + # Simulate prior-run state. 
+ builder._early_shutdown = True + builder._partial_row_groups = (0, 1) + builder._actual_num_records = 42 + builder._task_traces = ["trace"] # type: ignore[list-item] + + builder._reset_run_state() + + assert builder.early_shutdown is False + assert builder.partial_row_groups == () + assert builder.actual_num_records == -1 + assert builder.task_traces == [] + + # Processor tests diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py index 5cacdc283..1a619731e 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py @@ -183,6 +183,63 @@ def test_failure_releases_slot_without_limit_change(manager: ThrottleManager) -> assert state.in_flight == 0 +def test_failure_does_not_reset_cascade_while_burst_in_flight(manager: ThrottleManager) -> None: + """Mixed-response burst (429 → 500 → 429 with multiple slots in-flight) must reduce only once. + + With a real burst of in-flight requests, an interleaved non-rate-limit + failure should NOT break the cascade - otherwise the next 429 from the + same wave would be treated as a new cascade and double-reduce the limit + even though the provider hasn't recovered between the two 429s. + """ + # Saturate to limit (4 concurrent slots). + for _ in range(4): + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.in_flight == 4 + limit_before = state.current_limit + + # First 429 from the burst: limit reduced once. + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + limit_after_first_429 = state.current_limit + assert limit_after_first_429 < limit_before + assert state.consecutive_429s == 1 + assert state.in_flight == 3 + + # Second response from the same burst: 500. With the regression, this + # would reset the cascade to 0; with the fix, in_flight > 0 keeps it at 1. + manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert state.consecutive_429s == 1, "cascade must not reset while the prior burst is still in-flight" + assert state.in_flight == 2 + + # Third response from the same burst: another 429. With the regression + # this would be treated as a new cascade and reduce the limit again. + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert state.current_limit == limit_after_first_429, "limit must not double-reduce within the same burst" + assert state.in_flight == 1 + + +def test_failure_resets_cascade_after_burst_drains(manager: ThrottleManager) -> None: + """Once the burst fully drains (in_flight == 0), the next non-RL failure breaks the cascade. + + This preserves the original PR intent for the sequential 429 → 500 → 429 + case: provider rate-limited, settled, then rate-limited again. + """ + # Saturate, then drain: one 429 then one 500 with no concurrency. + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) + assert state is not None + assert state.consecutive_429s == 1 + assert state.in_flight == 0 + + # New request after the burst drained. 
release_failure sees in_flight 1 → 0. + manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) + assert state.consecutive_429s == 0 + assert state.in_flight == 0 + + # --- Global cap --- diff --git a/packages/data-designer/src/data_designer/interface/__init__.py b/packages/data-designer/src/data_designer/interface/__init__.py index ad434bebc..a8a1bd61b 100644 --- a/packages/data-designer/src/data_designer/interface/__init__.py +++ b/packages/data-designer/src/data_designer/interface/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from data_designer.interface.data_designer import DataDesigner # noqa: F401 from data_designer.interface.errors import ( # noqa: F401 + DataDesignerEarlyShutdownError, DataDesignerGenerationError, DataDesignerProfilingError, ) @@ -16,6 +17,7 @@ _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "DataDesigner": ("data_designer.interface.data_designer", "DataDesigner"), + "DataDesignerEarlyShutdownError": ("data_designer.interface.errors", "DataDesignerEarlyShutdownError"), "DataDesignerGenerationError": ("data_designer.interface.errors", "DataDesignerGenerationError"), "DataDesignerProfilingError": ("data_designer.interface.errors", "DataDesignerProfilingError"), "DatasetCreationResults": ("data_designer.interface.results", "DatasetCreationResults"), diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 30f8fe108..6e64ea91c 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -61,6 +61,7 @@ ) from data_designer.engine.storage.artifact_storage import ArtifactStorage from data_designer.interface.errors import ( + DataDesignerEarlyShutdownError, DataDesignerGenerationError, DataDesignerProfilingError, ) @@ -234,6 +235,20 @@ def create( try: dataset_for_profiler = builder.artifact_storage.load_dataset_with_dropped_columns() except Exception as e: + # Distinguish "early shutdown produced zero records" from generic load failures + # so callers can react programmatically (e.g. retry on a different alias) instead + # of parsing a wrapped FileNotFoundError. The scheduler's structured signal lives + # on the builder for the duration of the run. We also require the run to have + # produced zero records: a partial-salvage run that fails to load for unrelated + # reasons (corrupt parquet, dropped-columns mismatch, filesystem hiccup) should + # surface the original cause, not a misleading "zero records" diagnosis. + if builder.early_shutdown and builder.actual_num_records == 0: + raise DataDesignerEarlyShutdownError( + "🛑 Generation produced zero records — early shutdown was triggered. " + "The non-retryable error rate exceeded the configured threshold; check the " + "warnings above (and any 'Provider showing degraded performance' logs) for " + "the contributing failures." + ) from e raise DataDesignerGenerationError( f"🛑 Failed to load generated dataset — all records may have been dropped " f"due to generation failures. Check the warnings above for details. Original error: {e}" @@ -243,6 +258,15 @@ def create( # practice load_dataset_with_dropped_columns() would raise before returning a # zero-row DataFrame. This guard protects against future changes to that contract. 
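# The predicate shared by both guards in create() (sketch; hypothetical helper
# name - the real code inlines the condition): the typed error is reserved for
# runs that BOTH tripped the gate AND wrote nothing.

def is_zero_record_shutdown(builder: "DatasetBuilder") -> bool:
    return builder.early_shutdown and builder.actual_num_records == 0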
if len(dataset_for_profiler) == 0: + # Mirror the load-failure guard above: only raise the typed error when + # the run actually produced zero records. A partial-salvage run that + # somehow returns an empty DF for unrelated reasons should surface the + # generic error. + if builder.early_shutdown and builder.actual_num_records == 0: + raise DataDesignerEarlyShutdownError( + "🛑 Dataset is empty — early shutdown was triggered before any records " + "could complete. Check the warnings above for the contributing failures." + ) raise DataDesignerGenerationError( "🛑 Dataset is empty — all records were dropped due to generation failures. " "Check the warnings above for details on which columns failed." @@ -288,6 +312,8 @@ def preview( Raises: DataDesignerGenerationError: If an error occurs during preview dataset generation. + DataDesignerEarlyShutdownError: If preview terminated via the early-shutdown gate + with zero records produced. Subclass of ``DataDesignerGenerationError``. DataDesignerProfilingError: If an error occurs during preview dataset profiling. """ logger.info(f"{RandomEmoji.previewing()} Preview generation in progress") @@ -304,6 +330,14 @@ def preview( raise DataDesignerGenerationError(f"🛑 Error generating preview dataset: {e}") from e if len(processed_dataset) == 0: + # Mirror the create() path: distinguish "early shutdown produced zero + # records" from generic empty-dataset failures so callers can react + # programmatically. + if builder.early_shutdown and builder.actual_num_records == 0: + raise DataDesignerEarlyShutdownError( + "🛑 Preview is empty — early shutdown was triggered before any records " + "could complete. Check the warnings above for the contributing failures." + ) raise DataDesignerGenerationError( "🛑 Dataset is empty — all records were dropped due to generation or processing failures. " "Check the warnings above for details on which columns failed." diff --git a/packages/data-designer/src/data_designer/interface/errors.py b/packages/data-designer/src/data_designer/interface/errors.py index 1e5f5050c..3b113ef9b 100644 --- a/packages/data-designer/src/data_designer/interface/errors.py +++ b/packages/data-designer/src/data_designer/interface/errors.py @@ -14,5 +14,15 @@ class DataDesignerGenerationError(DataDesignerError): """Raised for errors related to a Data Designer dataset generation.""" +class DataDesignerEarlyShutdownError(DataDesignerGenerationError): + """Raised when a run terminated via early shutdown and produced no records. + + Subclass of ``DataDesignerGenerationError`` so existing handlers still catch + it; callers that want to distinguish the early-shutdown case (e.g. to retry + with a different model alias or surface a degraded-provider message to the + user) can catch this specific type. 
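    Example (illustrative sketch; ``dd`` is a configured DataDesigner)::

        try:
            dd.create(config_builder, num_records=100)
        except DataDesignerEarlyShutdownError:
            ...  # e.g. retry with a fallback model alias
        except DataDesignerGenerationError:
            raise  # other generation failures keep their original diagnosis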
+ """ + + class InvalidBufferValueError(DataDesignerError): """Raised for errors related to an invalid buffer value.""" diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index dc56b1a74..b59b8d669 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -3,12 +3,13 @@ from __future__ import annotations +import contextlib import json import logging from datetime import datetime from pathlib import Path from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest from pydantic import ValidationError @@ -39,7 +40,11 @@ from data_designer.engine.testing.seed_readers import LineFanoutDirectorySeedReader from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader from data_designer.interface.data_designer import DataDesigner -from data_designer.interface.errors import DataDesignerGenerationError, DataDesignerProfilingError +from data_designer.interface.errors import ( + DataDesignerEarlyShutdownError, + DataDesignerGenerationError, + DataDesignerProfilingError, +) class CustomDirectorySeedReader(FileSystemSeedReader[DirectorySeedSource]): @@ -635,50 +640,153 @@ def test_preview_raises_error_when_profiler_fails( assert isinstance(exc_info.value.__cause__, ValueError) -def test_create_raises_generation_error_when_dataset_is_empty( - stub_artifact_path, stub_model_providers, stub_sampler_only_config_builder, stub_managed_assets_path -): - """When all records are dropped during generation, create should raise - DataDesignerGenerationError with a clear message instead of a misleading profiler error. - """ - data_designer = DataDesigner( +def _patch_builder_state(*, early_shutdown: bool, actual_num_records: int = 0) -> contextlib.ExitStack: + """Patch DatasetBuilder.early_shutdown / actual_num_records as PropertyMocks.""" + stack = contextlib.ExitStack() + stack.enter_context( + patch( + "data_designer.engine.dataset_builders.dataset_builder.DatasetBuilder.early_shutdown", + new_callable=PropertyMock, + return_value=early_shutdown, + ) + ) + stack.enter_context( + patch( + "data_designer.engine.dataset_builders.dataset_builder.DatasetBuilder.actual_num_records", + new_callable=PropertyMock, + return_value=actual_num_records, + ) + ) + return stack + + +def _make_data_designer( + stub_artifact_path: Path, + stub_model_providers: list[ModelProvider], + stub_managed_assets_path: Path, +) -> DataDesigner: + return DataDesigner( artifact_path=stub_artifact_path, model_providers=stub_model_providers, secret_resolver=PlaintextResolver(), managed_assets_path=stub_managed_assets_path, ) - with patch( - "data_designer.engine.storage.artifact_storage.ArtifactStorage.load_dataset_with_dropped_columns", - return_value=lazy.pd.DataFrame(), - ): - with pytest.raises(DataDesignerGenerationError, match="Dataset is empty"): - data_designer.create(stub_sampler_only_config_builder, num_records=1) - -def test_create_raises_generation_error_when_load_dataset_fails( +# Matrix of error-dispatch behavior in create() when the load step doesn't return a +# usable dataset. Two side-effect modes (load raises FileNotFoundError vs. load +# returns an empty DF) crossed with three builder states (no shutdown, shutdown +# with zero records, shutdown with partial salvage). Each case asserts exactly +# which error type create() surfaces and the message it carries. 
+@pytest.mark.parametrize(
+    "load_side_effect,early_shutdown,actual_num_records,expected_exc,match,expect_filenotfound_cause",
+    [
+        # actual_num_records=-1 is a sentinel for "not patched": the builder-state
+        # patch is skipped entirely when early_shutdown is False, so the value is
+        # never read.
+        # Load raises FileNotFoundError → "Failed to load generated dataset" path.
+        pytest.param(
+            "raises",
+            False,
+            -1,
+            DataDesignerGenerationError,
+            "Failed to load generated dataset",
+            True,
+            id="load_fails_no_shutdown",
+        ),
+        pytest.param(
+            "raises",
+            True,
+            0,
+            DataDesignerEarlyShutdownError,
+            "early shutdown was triggered",
+            True,
+            id="load_fails_shutdown_zero_records",
+        ),
+        pytest.param(
+            "raises",
+            True,
+            7,
+            DataDesignerGenerationError,
+            "Failed to load generated dataset",
+            True,
+            id="load_fails_shutdown_partial_salvage",
+        ),
+        # Load returns empty DF → "Dataset is empty" defensive guard.
+        pytest.param(
+            "empty_df",
+            False,
+            -1,
+            DataDesignerGenerationError,
+            "Dataset is empty",
+            False,
+            id="empty_df_no_shutdown",
+        ),
+        pytest.param(
+            "empty_df",
+            True,
+            0,
+            DataDesignerEarlyShutdownError,
+            "early shutdown was triggered",
+            False,
+            id="empty_df_shutdown_zero_records",
+        ),
+        pytest.param(
+            "empty_df",
+            True,
+            7,
+            DataDesignerGenerationError,
+            "Dataset is empty",
+            False,
+            id="empty_df_shutdown_partial_salvage",
+        ),
+    ],
+)
+def test_create_error_dispatch_on_load_outcome(
     stub_artifact_path: Path,
     stub_model_providers: list[ModelProvider],
     stub_sampler_only_config_builder: DataDesignerConfigBuilder,
     stub_managed_assets_path: Path,
+    load_side_effect: str,
+    early_shutdown: bool,
+    actual_num_records: int,
+    expected_exc: type[Exception],
+    match: str,
+    expect_filenotfound_cause: bool,
 ) -> None:
-    """When no parquet was written (e.g. all records dropped), load_dataset_with_dropped_columns
-    raises an exception. create() should surface this as DataDesignerGenerationError, not
-    DataDesignerProfilingError.
+    """create() picks the right error type based on (load outcome × builder state).
+
+    The typed ``DataDesignerEarlyShutdownError`` only fires when the gate tripped
+    AND zero records were produced. Partial-salvage runs that fail to load (or
+    return empty for unrelated reasons) fall through to the generic error so the
+    real cause isn't masked.
""" - data_designer = DataDesigner( - artifact_path=stub_artifact_path, - model_providers=stub_model_providers, - secret_resolver=PlaintextResolver(), - managed_assets_path=stub_managed_assets_path, + data_designer = _make_data_designer(stub_artifact_path, stub_model_providers, stub_managed_assets_path) + + if load_side_effect == "raises": + load_patch = patch( + "data_designer.engine.storage.artifact_storage.ArtifactStorage.load_dataset_with_dropped_columns", + side_effect=FileNotFoundError("No parquet files found"), + ) + else: + load_patch = patch( + "data_designer.engine.storage.artifact_storage.ArtifactStorage.load_dataset_with_dropped_columns", + return_value=lazy.pd.DataFrame(), + ) + state_patch = ( + _patch_builder_state(early_shutdown=early_shutdown, actual_num_records=actual_num_records) + if early_shutdown + else contextlib.nullcontext() ) - with patch( - "data_designer.engine.storage.artifact_storage.ArtifactStorage.load_dataset_with_dropped_columns", - side_effect=FileNotFoundError("No parquet files found"), - ): - with pytest.raises(DataDesignerGenerationError, match="Failed to load generated dataset") as exc_info: - data_designer.create(stub_sampler_only_config_builder, num_records=1) + with load_patch, state_patch: + with pytest.raises(expected_exc, match=match) as exc_info: + data_designer.create(stub_sampler_only_config_builder, num_records=10) + + # Subclass relationship is the contract callers depend on - existing handlers + # for DataDesignerGenerationError must still catch the typed subclass. + if expected_exc is DataDesignerEarlyShutdownError: + assert isinstance(exc_info.value, DataDesignerGenerationError) + else: + assert not isinstance(exc_info.value, DataDesignerEarlyShutdownError) + if expect_filenotfound_cause: assert isinstance(exc_info.value.__cause__, FileNotFoundError) @@ -703,6 +811,47 @@ def test_preview_raises_generation_error_when_dataset_is_empty( data_designer.preview(stub_sampler_only_config_builder, num_records=1) +def test_preview_raises_early_shutdown_error_on_empty_after_shutdown( + stub_artifact_path: Path, + stub_model_providers: list[ModelProvider], + stub_sampler_only_config_builder: DataDesignerConfigBuilder, + stub_managed_assets_path: Path, +) -> None: + """Preview mirrors create(): typed early-shutdown error fires when shutdown produced zero records.""" + data_designer = _make_data_designer(stub_artifact_path, stub_model_providers, stub_managed_assets_path) + + with ( + patch( + "data_designer.engine.dataset_builders.dataset_builder.DatasetBuilder.process_preview", + return_value=lazy.pd.DataFrame(), + ), + _patch_builder_state(early_shutdown=True, actual_num_records=0), + ): + with pytest.raises(DataDesignerEarlyShutdownError, match="early shutdown was triggered"): + data_designer.preview(stub_sampler_only_config_builder, num_records=1) + + +def test_preview_raises_generic_error_when_partial_then_empty( + stub_artifact_path: Path, + stub_model_providers: list[ModelProvider], + stub_sampler_only_config_builder: DataDesignerConfigBuilder, + stub_managed_assets_path: Path, +) -> None: + """Preview falls through to the generic error when records were salvaged.""" + data_designer = _make_data_designer(stub_artifact_path, stub_model_providers, stub_managed_assets_path) + + with ( + patch( + "data_designer.engine.dataset_builders.dataset_builder.DatasetBuilder.process_preview", + return_value=lazy.pd.DataFrame(), + ), + _patch_builder_state(early_shutdown=True, actual_num_records=3), + ): + with pytest.raises(DataDesignerGenerationError, 
match="Dataset is empty") as exc_info: + data_designer.preview(stub_sampler_only_config_builder, num_records=10) + assert not isinstance(exc_info.value, DataDesignerEarlyShutdownError) + + def test_create_logs_secure_jinja_rendering_mode( stub_artifact_path: Path, stub_model_providers: list[ModelProvider],