@@ -15,6 +15,7 @@
 from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
 from data_designer.engine.column_generators.generators.base import SYNC_BRIDGE_TIMEOUT, ColumnGenerator
 from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
+from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError
 from data_designer.logging import LOG_INDENT
 
 if TYPE_CHECKING:
@@ -68,8 +69,13 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]:
             return future.result(timeout=SYNC_BRIDGE_TIMEOUT)
         except concurrent.futures.TimeoutError as exc:
             future.cancel()
-            logger.warning("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT)
-            raise TimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc
+            # Demoted to debug: the raised ModelTimeoutError already surfaces
+            # the timeout at the scheduler with full context, and the throttled
+            # degraded-provider WARN is the user-facing signal under sustained
+            # bridge timeouts. Per-event WARN was noise on top of those.
+            logger.debug("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT)
+            # Raise as ModelTimeoutError so the scheduler classifies it retryable.
+            raise ModelTimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc
 
     def __getattr__(self, name: str) -> Any:
         return getattr(object.__getattribute__(self, "_facade"), name)
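
The hunk above swaps the bare TimeoutError for the typed ModelTimeoutError and demotes the per-event log to debug. Below is a minimal standalone sketch of the sync-to-async bridge pattern it modifies, assuming a dedicated background event loop; the diff does not show how the facade actually schedules the coroutine, so that plumbing is a guess.

import asyncio
import concurrent.futures
import threading

SYNC_BRIDGE_TIMEOUT = 300.0  # placeholder; the real constant is imported above


class ModelTimeoutError(TimeoutError):
    """Stand-in for data_designer.engine.models.errors.ModelTimeoutError."""


# Background loop that services coroutines submitted from sync code.
_loop = asyncio.new_event_loop()
threading.Thread(target=_loop.run_forever, daemon=True).start()


def bridge_generate(coro):
    """Run an async generate() coroutine from sync code with a hard timeout."""
    future = asyncio.run_coroutine_threadsafe(coro, _loop)
    try:
        return future.result(timeout=SYNC_BRIDGE_TIMEOUT)
    except concurrent.futures.TimeoutError as exc:
        future.cancel()  # stop the orphaned coroutine on the loop
        # Typed error: the scheduler's _is_retryable() sees a retryable
        # model error, so the timeout feeds salvage, not the shutdown gate.
        raise ModelTimeoutError(
            f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s"
        ) from exc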
@@ -147,6 +153,10 @@ async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list
             result = await self._ainvoke_generator_function(data)
         except CustomColumnGenerationError:
             raise
+        except RETRYABLE_MODEL_ERRORS:
+            # Preserve retryability so the scheduler can salvage these
+            # instead of counting them toward the early-shutdown gate.
+            raise
         except Exception as e:
             logger.warning(
                 f"⚠️ Custom generator function {self.config.generator_function.__name__!r} "
@@ -193,6 +203,10 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
             result = self._invoke_generator_function(data)
         except CustomColumnGenerationError:
             raise
+        except RETRYABLE_MODEL_ERRORS:
+            # Preserve retryability so the scheduler can salvage these
+            # instead of counting them toward the early-shutdown gate.
+            raise
         except Exception as e:
             if not is_dataframe:
                 logger.warning(
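
Both agenerate and _generate gain the identical passthrough, and the ordering is the point: except clauses match top-down, so the retryable tuple must sit above the blanket Exception handler or a retryable model error would be swallowed and re-wrapped, losing its classification. A reduced sketch of the pattern with stand-in error types (the real tuple holds ModelRateLimitError, ModelTimeoutError, ModelInternalServerError, and ModelAPIConnectionError, as the scheduler hunks below show):

class ModelRateLimitError(Exception): ...
class ModelTimeoutError(TimeoutError): ...

RETRYABLE_MODEL_ERRORS = (ModelRateLimitError, ModelTimeoutError)


def invoke(generator_function):
    try:
        return generator_function()
    except RETRYABLE_MODEL_ERRORS:
        # Re-raise unwrapped so an isinstance() check downstream still
        # sees the concrete model-error type.
        raise
    except Exception as exc:
        # Anything else is treated as a user-code bug; RuntimeError is a
        # stand-in for the wrapping the real handler performs.
        raise RuntimeError(f"custom generator failed: {exc}") from exc

The remaining hunks are scheduler-side: the tuple's definition leaves the scheduler module for data_designer.engine.models.errors, which both files now import.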
@@ -30,12 +30,7 @@
 )
 from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
 from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace
-from data_designer.engine.models.errors import (
-    ModelAPIConnectionError,
-    ModelInternalServerError,
-    ModelRateLimitError,
-    ModelTimeoutError,
-)
+from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS
 
 if TYPE_CHECKING:
     from data_designer.engine.column_generators.generators.base import ColumnGenerator
@@ -47,12 +42,13 @@
 DEFAULT_TASK_POOL_SIZE: int = 256
 LLM_WAIT_POOL_MULTIPLIER: int = 2
 
-_RETRYABLE_MODEL_ERRORS = (
-    ModelRateLimitError,
-    ModelTimeoutError,
-    ModelInternalServerError,
-    ModelAPIConnectionError,
-)
+# Degraded-provider WARN: emit at most one warning per interval when the
+# rolling fraction of retryable errors exceeds the threshold. Distinct from
+# the early-shutdown gate (which fires on non-retryable errors).
+# TODO: thread these through RunConfig so users can tune them per run.
+DEGRADED_WARN_RATE: float = 0.5
+DEGRADED_WARN_WINDOW: int = 20
+DEGRADED_WARN_INTERVAL_S: float = 60.0
 
 
 class TrackingSemaphore(asyncio.Semaphore):
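
Until the TODO above lands, the three knobs travel as plain constructor parameters (added to __init__ in the next hunk). A hypothetical instantiation quieting the WARN for a known-noisy provider; the scheduler's class name and its other required arguments are not visible in this diff, so they are stand-ins:

scheduler = TaskScheduler(  # class name assumed, not shown in the hunks
    ...,                             # the remaining constructor arguments
    degraded_warn_rate=0.8,          # warn only when 80% of the window degrades
    degraded_warn_window=50,         # judged over the last 50 LLM-bound outcomes
    degraded_warn_interval_s=300.0,  # and log at most once every five minutes
)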
@@ -105,6 +101,9 @@ def __init__(
         shutdown_error_rate: float = 0.5,
         shutdown_error_window: int = 10,
         disable_early_shutdown: bool = False,
+        degraded_warn_rate: float = DEGRADED_WARN_RATE,
+        degraded_warn_window: int = DEGRADED_WARN_WINDOW,
+        degraded_warn_interval_s: float = DEGRADED_WARN_INTERVAL_S,
         trace: bool = False,
         num_records: int = 0,
         buffer_size: int = 0,
@@ -177,6 +176,22 @@ def __init__(
         self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window)
         self._all_rgs_admitted = False
 
+        # Degraded-provider WARN: separate window tracking retryable-vs-not for
+        # every outcome (success or failure), throttled to one log per interval.
+        self._degraded_warn_rate = degraded_warn_rate
+        self._degraded_warn_window = degraded_warn_window
+        self._degraded_warn_interval_s = degraded_warn_interval_s
+        self._recent_retryable: deque[bool] = deque(maxlen=degraded_warn_window)
+        # Initialize to -inf so the first WARN is always emitted regardless of
+        # the monotonic clock's absolute value (which can be near-zero on freshly
+        # booted CI runners).
+        self._last_degraded_warn_at: float = float("-inf")
+
+        # Row groups that were partially salvaged after early shutdown
+        # (i.e., some rows complete, some incomplete-then-dropped). Surfaced
+        # via the partial_row_groups property as a structured signal.
+        self._partial_row_groups: list[int] = []
+
         # Pre-compute row-group sizes for O(1) lookup
         self._rg_size_map: dict[int, int] = dict(row_groups)
 
@@ -221,6 +236,20 @@ def _setup_async_progress_reporter(
     def active_worker_count(self) -> int:
         return sum(1 for t in self._worker_tasks if not t.done())
 
+    @property
+    def early_shutdown(self) -> bool:
+        """True if the run terminated via the early-shutdown gate."""
+        return self._early_shutdown
+
+    @property
+    def partial_row_groups(self) -> tuple[int, ...]:
+        """Row group ids that were partially salvaged after early shutdown.
+
+        Empty unless ``early_shutdown`` is True. Each id had some rows
+        complete and the rest dropped before checkpointing.
+        """
+        return tuple(self._partial_row_groups)
+
     def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task:
         """Create a tracked worker task that auto-removes itself on completion."""
         task = asyncio.create_task(coro)
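
The two new properties give callers a structured signal in place of log scraping. A hedged caller-side sketch; the driver that owns the scheduler is not part of this diff, so the instance arrives as a parameter:

async def drive(scheduler) -> None:
    await scheduler.run()
    if scheduler.early_shutdown:
        # Tuple of row-group ids that checkpointed only some of their
        # rows; empty on a clean run.
        print("partially salvaged row groups:", scheduler.partial_row_groups)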
@@ -270,13 +299,9 @@ async def run(self) -> None:
         # Launch admission as a background task so it interleaves with dispatch.
         admission_task = asyncio.create_task(self._admit_row_groups())
 
-        dispatch_error: BaseException | None = None
         try:
             # Main dispatch loop
             await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns)
-        except BaseException as exc:
-            dispatch_error = exc
-            raise
         finally:
             # Always cancel admission + drain in-flight workers, regardless
             # of how the dispatch loop exited (normal, early shutdown,
@@ -286,11 +311,18 @@
             with contextlib.suppress(asyncio.CancelledError):
                 await admission_task
             await asyncio.shield(self._cancel_workers())
 
+            # Salvage partially-complete row groups left over from early
+            # shutdown. Must run AFTER _cancel_workers - in-flight tasks
+            # could otherwise write into a buffer that's being finalized.
+            if self._early_shutdown and self._rg_states:
+                self._finalize_after_shutdown(all_columns)
+
+        # Reached only on the clean-exit path; an exception in the
+        # dispatch loop or the finally block propagates and skips this.
         if self._reporter:
             self._reporter.log_final()
-
-        if self._rg_states and dispatch_error is None:
+        if self._rg_states:
             incomplete = list(self._rg_states)
             logger.error(
                 f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. "
@@ -306,7 +338,7 @@ async def _main_dispatch_loop(
         """Core dispatch loop extracted from ``run()``."""
         while True:
             if self._early_shutdown:
-                logger.warning("Early shutdown triggered - error rate exceeded threshold")
+                logger.warning("Early shutdown triggered - non-retryable error rate exceeded threshold")
                 if self._deferred:
                     await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns)
                 self._checkpoint_completed_row_groups(all_columns)
@@ -534,6 +566,44 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None:
         checkpointed = {rg_id for rg_id, _ in completed}
         self._deferred = [t for t in self._deferred if t.row_group not in checkpointed]
 
+    def _finalize_after_shutdown(self, all_columns: list[str]) -> None:
+        """Salvage row groups left in flight when early shutdown fired.
+
+        For each remaining row group, drop rows that aren't fully complete
+        (and weren't already dropped); after that, ``is_row_group_complete``
+        is true by construction over the surviving rows, so delegating to
+        ``_checkpoint_completed_row_groups`` writes survivors and frees
+        zero-survivor groups via the buffer manager's existing logic.
+
+        Note on processors: ``_checkpoint_completed_row_groups`` calls
+        ``on_before_checkpoint`` (post-batch) but never ``on_seeds_complete``
+        (pre-batch). If the gate fires before seeds completed for a row
+        group, that row group's pre-batch processor never ran. Survivors
+        are checkpointed without it. This is the existing contract for
+        partial-row-group salvage.
+        """
+        for rg_id in list(self._rg_states.keys()):
+            rg_size = self._rg_states[rg_id].size
+            had_incomplete = False
+            for ri in range(rg_size):
+                if self._tracker.is_dropped(rg_id, ri):
+                    continue
+                if all(
+                    self._tracker.is_complete(SliceRef(column=col, row_group=rg_id, row_index=ri))
+                    for col in all_columns
+                ):
+                    continue
+                had_incomplete = True
+                self._drop_row(rg_id, ri)
+            if had_incomplete:
+                survivors = sum(1 for ri in range(rg_size) if not self._tracker.is_dropped(rg_id, ri))
+                if survivors > 0:
+                    self._partial_row_groups.append(rg_id)
+                    logger.warning(f"Row group {rg_id}: salvaging {survivors} of {rg_size} rows after early shutdown.")
+                else:
+                    logger.warning(f"Row group {rg_id}: 0 of {rg_size} rows survived early shutdown - skipping write.")
+        self._checkpoint_completed_row_groups(all_columns)
+
     def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None:
         """Run pre-batch callbacks for row groups whose seeds just completed."""
         for rg_id, state in list(self._rg_states.items()):
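
To make the drop rule concrete, a toy walk-through of a three-row group using plain sets in place of the real tracker: row 0 is fully complete, row 1 was dropped earlier in the run, row 2 is still mid-flight. Only row 2 gets dropped here, the group records one survivor, and its id lands in _partial_row_groups:

rg_size = 3
complete = {0}  # rows with every column cell complete
dropped = {1}   # rows dropped earlier in the run

for ri in range(rg_size):
    if ri in dropped or ri in complete:
        continue  # already gone, or a survivor
    dropped.add(ri)  # incomplete survivor: drop it so the group "completes"

survivors = rg_size - len(dropped)
assert dropped == {1, 2} and survivors == 1  # partial salvage for this group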
@@ -606,6 +676,35 @@ def _check_error_rate(self, *, success: bool) -> None:
         if errors / self._shutdown_error_window >= self._shutdown_error_rate:
             self._early_shutdown = True
 
+    def _record_retryable_outcome(self, *, retryable: bool) -> None:
+        """Track retryable-error rate and emit a throttled WARN under provider degradation.
+
+        Distinct from ``_check_error_rate``: every LLM-bound task outcome (success
+        or failure) feeds this window so the rate reflects the provider's overall
+        health, not just the error mix. The call site filters on ``is_llm`` so
+        non-LLM tasks (samplers, expressions, non-LLM customs) don't dilute the
+        rate. Only retryable errors (rate-limit, timeout, 5xx, connection) count
+        toward the rate; non-retryable failures register as 0.
+        """
+        if self._degraded_warn_window <= 0:
+            return
+        self._recent_retryable.append(retryable)
+        if len(self._recent_retryable) < self._degraded_warn_window:
+            return
+        rate = sum(self._recent_retryable) / self._degraded_warn_window
+        if rate < self._degraded_warn_rate:
+            return
+        now = time.monotonic()
+        if now - self._last_degraded_warn_at < self._degraded_warn_interval_s:
+            return
+        self._last_degraded_warn_at = now
+        pct = int(round(rate * 100))
+        logger.warning(
+            f"Provider showing degraded performance: {pct}% of last {self._degraded_warn_window} "
+            "task outcomes were retryable errors (rate-limit, timeout, 5xx, connection). "
+            "Run may take longer than expected; salvage will retry these."
+        )
+
     async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
         """Dispatch from_scratch tasks for a row group."""
         self._rg_states[rg_id].seeds_dispatched = True
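
With the defaults, the WARN cannot fire before the window fills: at least 10 of the last 20 LLM-bound outcomes must be retryable errors, and afterwards at most one log per 60 s. A standalone simulation of the window-plus-throttle logic with a made-up outcome stream:

import time
from collections import deque

WINDOW, RATE, INTERVAL_S = 20, 0.5, 60.0  # mirror the DEGRADED_WARN_* defaults

recent: deque[bool] = deque(maxlen=WINDOW)
last_warn = float("-inf")  # so the first qualifying WARN always fires


def record(retryable: bool) -> None:
    global last_warn
    recent.append(retryable)
    if len(recent) < WINDOW or sum(recent) / WINDOW < RATE:
        return
    now = time.monotonic()
    if now - last_warn < INTERVAL_S:
        return  # throttled: already warned inside this interval
    last_warn = now
    print(f"degraded: {sum(recent)}/{WINDOW} recent outcomes were retryable")


# Twelve retryable errors then eight successes: the rate is 0.6 when the
# window fills at the 20th outcome, so exactly one line prints.
for outcome in [True] * 12 + [False] * 8:
    record(outcome)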
@@ -730,6 +829,12 @@ async def _execute_task_inner_impl(self, task: Task) -> None:
                     self._tracker.mark_cell_complete(col, task.row_group, task.row_index)
 
             self._check_error_rate(success=True)
+            # The degraded-provider WARN is provider-scoped: only feed the
+            # window from LLM-bound tasks so a healthy non-model task mix
+            # (samplers, expressions, non-LLM customs) doesn't dilute the
+            # rate and silence the WARN under genuine provider stress.
+            if is_llm:
+                self._record_retryable_outcome(retryable=False)
             if self._reporter:
                 if cell_skipped:
                     self._reporter.record_skipped(task.column)
@@ -739,9 +844,15 @@
                 trace.status = "ok"
 
         except Exception as exc:
-            if not isinstance(exc, ModelRateLimitError):
-                self._check_error_rate(success=False)
             retryable = self._is_retryable(exc)
+            # Only non-retryable errors (auth, schema, code bugs) count toward
+            # the early-shutdown gate. Retryable errors (rate-limit, timeout,
+            # transient 5xx, connection blips) cluster under provider degradation
+            # and would otherwise trip the gate even when salvage could recover.
+            if not retryable:
+                self._check_error_rate(success=False)
+            if is_llm:
+                self._record_retryable_outcome(retryable=retryable)
             if not retryable and self._reporter:
                 self._reporter.record_failure(task.column)
             if self._trace and trace:
@@ -930,7 +1041,7 @@ def get_semaphore_permits(self) -> tuple[int, int]:
     @staticmethod
     def _is_retryable(exc: Exception) -> bool:
         """Classify whether an exception is retryable."""
-        return isinstance(exc, _RETRYABLE_MODEL_ERRORS)
+        return isinstance(exc, RETRYABLE_MODEL_ERRORS)
 
 
 def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]: