Skip to content

Commit 763eedd

Browse files
committed
fix(async): salvage partial row groups on early shutdown
Before: when the early-shutdown gate fired, any row group still in flight stayed in `_rg_states` un-checkpointed. The buffer manager later raised `FileNotFoundError` when the builder tried to read the finalized parquet. User-visible result: `0 records produced`. After: a new `_finalize_after_shutdown` step runs in `run()`'s finally block, after `_cancel_workers` has drained in-flight tasks (Codex caveat: in-flight `from_scratch`/`batch` tasks must not be allowed to write into a buffer that's being finalized). For each remaining row group it drops rows that aren't fully complete, then delegates to the existing `_checkpoint_completed_row_groups` so the buffer manager's zero-survivor handling (skip empty parquet, free buffer) kicks in unchanged. Also surfaces partial completion as a structured signal: scheduler exposes `early_shutdown: bool` and `partial_row_groups: tuple[int, ...]` properties so callers can detect partial completion programmatically rather than parsing log lines. Builder uses this to emit a more specific WARN distinguishing early shutdown from non-shutdown drops. Refs #575.
1 parent 6a85d1e commit 763eedd

3 files changed

Lines changed: 248 additions & 4 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def __init__(
195195
self._recent_retryable: deque[bool] = deque(maxlen=degraded_warn_window)
196196
self._last_degraded_warn_at: float = 0.0
197197

198+
# Row groups that were partially salvaged after early shutdown
199+
# (i.e., some rows complete, some incomplete-then-dropped). Surfaced
200+
# via the partial_row_groups property as a structured signal.
201+
self._partial_row_groups: list[int] = []
202+
198203
# Pre-compute row-group sizes for O(1) lookup
199204
self._rg_size_map: dict[int, int] = dict(row_groups)
200205

@@ -239,6 +244,20 @@ def _setup_async_progress_reporter(
239244
def active_worker_count(self) -> int:
240245
return sum(1 for t in self._worker_tasks if not t.done())
241246

247+
@property
248+
def early_shutdown(self) -> bool:
249+
"""True if the run terminated via the early-shutdown gate."""
250+
return self._early_shutdown
251+
252+
@property
253+
def partial_row_groups(self) -> tuple[int, ...]:
254+
"""Row group ids that were partially salvaged after early shutdown.
255+
256+
Empty unless ``early_shutdown`` is True. Each id had some rows
257+
complete and the rest dropped before checkpointing.
258+
"""
259+
return tuple(self._partial_row_groups)
260+
242261
def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task:
243262
"""Create a tracked worker task that auto-removes itself on completion."""
244263
task = asyncio.create_task(coro)
@@ -304,6 +323,11 @@ async def run(self) -> None:
304323
with contextlib.suppress(asyncio.CancelledError):
305324
await admission_task
306325
await asyncio.shield(self._cancel_workers())
326+
# Salvage partially-complete row groups left over from early
327+
# shutdown. Must run AFTER _cancel_workers - in-flight tasks
328+
# could otherwise write into a buffer that's being finalized.
329+
if self._early_shutdown and self._rg_states:
330+
self._finalize_after_shutdown(all_columns)
307331

308332
if self._reporter:
309333
self._reporter.log_final()
@@ -552,6 +576,37 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None:
552576
checkpointed = {rg_id for rg_id, _ in completed}
553577
self._deferred = [t for t in self._deferred if t.row_group not in checkpointed]
554578

579+
def _finalize_after_shutdown(self, all_columns: list[str]) -> None:
580+
"""Salvage row groups left in flight when early shutdown fired.
581+
582+
For each remaining row group, drop rows that aren't fully complete
583+
(and weren't already dropped); after that, ``is_row_group_complete``
584+
is true by construction over the surviving rows, so delegating to
585+
``_checkpoint_completed_row_groups`` writes survivors and frees
586+
zero-survivor groups via the buffer manager's existing logic.
587+
"""
588+
for rg_id in list(self._rg_states.keys()):
589+
rg_size = self._rg_states[rg_id].size
590+
had_incomplete = False
591+
for ri in range(rg_size):
592+
if self._tracker.is_dropped(rg_id, ri):
593+
continue
594+
if all(
595+
self._tracker.is_complete(SliceRef(column=col, row_group=rg_id, row_index=ri))
596+
for col in all_columns
597+
):
598+
continue
599+
had_incomplete = True
600+
self._drop_row(rg_id, ri)
601+
if had_incomplete:
602+
survivors = sum(1 for ri in range(rg_size) if not self._tracker.is_dropped(rg_id, ri))
603+
if survivors > 0:
604+
self._partial_row_groups.append(rg_id)
605+
logger.warning(f"Row group {rg_id}: salvaging {survivors} of {rg_size} rows after early shutdown.")
606+
else:
607+
logger.warning(f"Row group {rg_id}: 0 of {rg_size} rows survived early shutdown - skipping write.")
608+
self._checkpoint_completed_row_groups(all_columns)
609+
555610
def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None:
556611
"""Run pre-batch callbacks for row groups whose seeds just completed."""
557612
for rg_id, state in list(self._rg_states.items()):

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,18 @@ def on_complete(final_path: Path | str | None) -> None:
341341
actual = buffer_manager.actual_num_records
342342
if actual < num_records:
343343
pct = actual / num_records * 100 if num_records > 0 else 0
344-
logger.warning(
345-
f"⚠️ Generated {actual} of {num_records} requested records ({pct:.0f}%). "
346-
"The dataset may be incomplete due to errors or early shutdown."
347-
)
344+
base = f"⚠️ Generated {actual} of {num_records} requested records ({pct:.0f}%). "
345+
if scheduler.early_shutdown:
346+
partial = scheduler.partial_row_groups
347+
detail = (
348+
f"Early shutdown was triggered (non-retryable error rate exceeded threshold); "
349+
f"{len(partial)} row group(s) salvaged with partial rows."
350+
if partial
351+
else "Early shutdown was triggered (non-retryable error rate exceeded threshold)."
352+
)
353+
logger.warning(base + detail)
354+
else:
355+
logger.warning(base + "The dataset may be incomplete due to dropped rows.")
348356

349357
def _prepare_async_run(
350358
self,

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,52 @@ def generate(self, data: dict) -> dict:
173173
return data
174174

175175

176+
class MockSelectiveFailGenerator(ColumnGenerator[ExpressionColumnConfig]):
177+
"""Cell generator with deterministic per-seed behavior.
178+
179+
- Seeds in ``fail_on_seeds``: raise a non-retryable ``ValueError`` immediately.
180+
- Seeds in ``slow_seeds``: block on ``slow_event`` (or ``asyncio.sleep``) so
181+
they remain in-flight when the early-shutdown gate fires.
182+
- All others: succeed.
183+
"""
184+
185+
def __init__(
186+
self,
187+
*args: Any,
188+
fail_on_seeds: set[int] = frozenset(),
189+
slow_seeds: set[int] = frozenset(),
190+
slow_timeout_s: float = 5.0,
191+
**kwargs: Any,
192+
) -> None:
193+
super().__init__(*args, **kwargs)
194+
self._fail = set(fail_on_seeds)
195+
self._slow = set(slow_seeds)
196+
self._slow_timeout_s = slow_timeout_s
197+
198+
@staticmethod
199+
def get_generation_strategy() -> GenerationStrategy:
200+
return GenerationStrategy.CELL_BY_CELL
201+
202+
async def agenerate(self, data: dict) -> dict:
203+
seed = data.get("seed")
204+
if seed in self._fail:
205+
raise ValueError(f"non-retryable on seed={seed}")
206+
if seed in self._slow:
207+
try:
208+
await asyncio.sleep(self._slow_timeout_s)
209+
except asyncio.CancelledError:
210+
raise
211+
data[self.config.name] = f"ok_{seed}"
212+
return data
213+
214+
def generate(self, data: dict) -> dict:
215+
seed = data.get("seed")
216+
if seed in self._fail:
217+
raise ValueError(f"non-retryable on seed={seed}")
218+
data[self.config.name] = f"ok_{seed}"
219+
return data
220+
221+
176222
class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]):
177223
"""Generator that raises a parametrizable retryable error then succeeds."""
178224

@@ -722,6 +768,141 @@ async def test_scheduler_error_rate_shutdown() -> None:
722768

723769
# Early shutdown: not all rows should be checkpointed (some row groups incomplete)
724770
assert buffer_mgr.actual_num_records < 10
771+
# No leftover unfinished row groups (finalize-after-shutdown drains them).
772+
assert not scheduler._rg_states
773+
774+
775+
@pytest.mark.asyncio(loop_scope="session")
776+
async def test_partial_row_group_salvaged_after_early_shutdown() -> None:
777+
"""Mid-run shutdown drops incomplete rows and checkpoints survivors."""
778+
provider = _mock_provider()
779+
configs = [
780+
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
781+
LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
782+
]
783+
strategies = {
784+
"seed": GenerationStrategy.FULL_COLUMN,
785+
"cell_out": GenerationStrategy.CELL_BY_CELL,
786+
}
787+
# 3 succeed (0,1,2), 3 fail non-retryable (5,6,7), 4 stay in-flight (3,4,8,9)
788+
# until cancellation. Window=4, rate=0.5 → gate trips after ~3-5 outcomes.
789+
generators = {
790+
"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
791+
"cell_out": MockSelectiveFailGenerator(
792+
config=_expr_config("cell_out"),
793+
resource_provider=provider,
794+
fail_on_seeds={5, 6, 7},
795+
slow_seeds={3, 4, 8, 9},
796+
),
797+
}
798+
799+
graph = ExecutionGraph.create(configs, strategies)
800+
row_groups = [(0, 10)]
801+
tracker = CompletionTracker.with_graph(graph, row_groups)
802+
803+
storage = MagicMock()
804+
storage.dataset_name = "test"
805+
storage.get_file_paths.return_value = {}
806+
storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
807+
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
808+
buffer_mgr = RowGroupBufferManager(storage)
809+
810+
finalized: list[int] = []
811+
812+
def on_finalize(rg_id: int) -> None:
813+
buffer_mgr.checkpoint_row_group(rg_id)
814+
finalized.append(rg_id)
815+
816+
scheduler = AsyncTaskScheduler(
817+
generators=generators,
818+
graph=graph,
819+
tracker=tracker,
820+
row_groups=row_groups,
821+
buffer_manager=buffer_mgr,
822+
on_finalize_row_group=on_finalize,
823+
shutdown_error_rate=0.5,
824+
shutdown_error_window=4,
825+
)
826+
await scheduler.run()
827+
828+
assert scheduler.early_shutdown
829+
# The row group survived with the 3 fast successes; the in-flight rows were
830+
# cancelled and dropped by _finalize_after_shutdown.
831+
assert 0 in finalized
832+
assert scheduler.partial_row_groups == (0,)
833+
# Exactly 3 rows survived (seeds 0, 1, 2).
834+
assert buffer_mgr.actual_num_records == 3
835+
836+
837+
@pytest.mark.asyncio(loop_scope="session")
838+
async def test_zero_survivor_shutdown_does_not_raise() -> None:
839+
"""If every row is dropped at shutdown, the row group is freed without writing parquet."""
840+
provider = _mock_provider()
841+
configs = [
842+
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
843+
LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
844+
]
845+
strategies = {
846+
"seed": GenerationStrategy.FULL_COLUMN,
847+
"cell_out": GenerationStrategy.CELL_BY_CELL,
848+
}
849+
# All 5 seeds fail non-retryable → all rows dropped before any can complete.
850+
generators = {
851+
"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
852+
"cell_out": MockSelectiveFailGenerator(
853+
config=_expr_config("cell_out"),
854+
resource_provider=provider,
855+
fail_on_seeds=set(range(5)),
856+
),
857+
}
858+
859+
graph = ExecutionGraph.create(configs, strategies)
860+
row_groups = [(0, 5)]
861+
tracker = CompletionTracker.with_graph(graph, row_groups)
862+
863+
storage = MagicMock()
864+
storage.dataset_name = "test"
865+
storage.get_file_paths.return_value = {}
866+
storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
867+
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
868+
buffer_mgr = RowGroupBufferManager(storage)
869+
870+
finalized: list[int] = []
871+
872+
def on_finalize(rg_id: int) -> None:
873+
buffer_mgr.checkpoint_row_group(rg_id)
874+
finalized.append(rg_id)
875+
876+
scheduler = AsyncTaskScheduler(
877+
generators=generators,
878+
graph=graph,
879+
tracker=tracker,
880+
row_groups=row_groups,
881+
buffer_manager=buffer_mgr,
882+
on_finalize_row_group=on_finalize,
883+
shutdown_error_rate=0.5,
884+
shutdown_error_window=2,
885+
)
886+
# Must not raise (no FileNotFoundError, no DataDesignerGenerationError).
887+
await scheduler.run()
888+
889+
assert scheduler.early_shutdown
890+
assert buffer_mgr.actual_num_records == 0
891+
# All rows dropped → checkpoint path frees buffer without writing; on_finalize
892+
# is *not* called because every row was dropped before survivors could exist.
893+
assert finalized == []
894+
# No partial-row-groups recorded — there were no incomplete-but-not-dropped rows.
895+
assert scheduler.partial_row_groups == ()
896+
storage.write_batch_to_parquet_file.assert_not_called()
897+
898+
899+
@pytest.mark.asyncio(loop_scope="session")
900+
async def test_healthy_run_has_no_partial_signal() -> None:
901+
"""Successful run leaves early_shutdown=False and partial_row_groups empty."""
902+
scheduler, _tracker = _build_simple_pipeline(num_records=3)
903+
await scheduler.run()
904+
assert not scheduler.early_shutdown
905+
assert scheduler.partial_row_groups == ()
725906

726907

727908
@pytest.mark.asyncio(loop_scope="session")

0 commit comments

Comments (0)