@@ -173,6 +173,52 @@ def generate(self, data: dict) -> dict:
173173 return data
174174
175175
176+ class MockSelectiveFailGenerator (ColumnGenerator [ExpressionColumnConfig ]):
177+ """Cell generator with deterministic per-seed behavior.
178+
179+ - Seeds in ``fail_on_seeds``: raise a non-retryable ``ValueError`` immediately.
180+ - Seeds in ``slow_seeds``: block on ``slow_event`` (or ``asyncio.sleep``) so
181+ they remain in-flight when the early-shutdown gate fires.
182+ - All others: succeed.
183+ """
184+
185+ def __init__ (
186+ self ,
187+ * args : Any ,
188+ fail_on_seeds : set [int ] = frozenset (),
189+ slow_seeds : set [int ] = frozenset (),
190+ slow_timeout_s : float = 5.0 ,
191+ ** kwargs : Any ,
192+ ) -> None :
193+ super ().__init__ (* args , ** kwargs )
194+ self ._fail = set (fail_on_seeds )
195+ self ._slow = set (slow_seeds )
196+ self ._slow_timeout_s = slow_timeout_s
197+
198+ @staticmethod
199+ def get_generation_strategy () -> GenerationStrategy :
200+ return GenerationStrategy .CELL_BY_CELL
201+
202+ async def agenerate (self , data : dict ) -> dict :
203+ seed = data .get ("seed" )
204+ if seed in self ._fail :
205+ raise ValueError (f"non-retryable on seed={ seed } " )
206+ if seed in self ._slow :
207+ try :
208+ await asyncio .sleep (self ._slow_timeout_s )
209+ except asyncio .CancelledError :
210+ raise
211+ data [self .config .name ] = f"ok_{ seed } "
212+ return data
213+
214+ def generate (self , data : dict ) -> dict :
215+ seed = data .get ("seed" )
216+ if seed in self ._fail :
217+ raise ValueError (f"non-retryable on seed={ seed } " )
218+ data [self .config .name ] = f"ok_{ seed } "
219+ return data
220+
221+
176222class MockRetryableErrorGenerator (ColumnGenerator [ExpressionColumnConfig ]):
177223 """Generator that raises a parametrizable retryable error then succeeds."""
178224
@@ -722,6 +768,141 @@ async def test_scheduler_error_rate_shutdown() -> None:
722768
723769 # Early shutdown: not all rows should be checkpointed (some row groups incomplete)
724770 assert buffer_mgr .actual_num_records < 10
771+ # No leftover unfinished row groups (finalize-after-shutdown drains them).
772+ assert not scheduler ._rg_states
773+
774+
775+ @pytest .mark .asyncio (loop_scope = "session" )
776+ async def test_partial_row_group_salvaged_after_early_shutdown () -> None :
777+ """Mid-run shutdown drops incomplete rows and checkpoints survivors."""
778+ provider = _mock_provider ()
779+ configs = [
780+ SamplerColumnConfig (name = "seed" , sampler_type = SamplerType .CATEGORY , params = {"values" : ["A" ]}),
781+ LLMTextColumnConfig (name = "cell_out" , prompt = "{{ seed }}" , model_alias = MODEL_ALIAS ),
782+ ]
783+ strategies = {
784+ "seed" : GenerationStrategy .FULL_COLUMN ,
785+ "cell_out" : GenerationStrategy .CELL_BY_CELL ,
786+ }
787+ # 3 succeed (0,1,2), 3 fail non-retryable (5,6,7), 4 stay in-flight (3,4,8,9)
788+ # until cancellation. Window=4, rate=0.5 → gate trips after ~3-5 outcomes.
789+ generators = {
790+ "seed" : MockSeedGenerator (config = _expr_config ("seed" ), resource_provider = provider ),
791+ "cell_out" : MockSelectiveFailGenerator (
792+ config = _expr_config ("cell_out" ),
793+ resource_provider = provider ,
794+ fail_on_seeds = {5 , 6 , 7 },
795+ slow_seeds = {3 , 4 , 8 , 9 },
796+ ),
797+ }
798+
799+ graph = ExecutionGraph .create (configs , strategies )
800+ row_groups = [(0 , 10 )]
801+ tracker = CompletionTracker .with_graph (graph , row_groups )
802+
803+ storage = MagicMock ()
804+ storage .dataset_name = "test"
805+ storage .get_file_paths .return_value = {}
806+ storage .write_batch_to_parquet_file .return_value = "/fake.parquet"
807+ storage .move_partial_result_to_final_file_path .return_value = "/fake_final.parquet"
808+ buffer_mgr = RowGroupBufferManager (storage )
809+
810+ finalized : list [int ] = []
811+
812+ def on_finalize (rg_id : int ) -> None :
813+ buffer_mgr .checkpoint_row_group (rg_id )
814+ finalized .append (rg_id )
815+
816+ scheduler = AsyncTaskScheduler (
817+ generators = generators ,
818+ graph = graph ,
819+ tracker = tracker ,
820+ row_groups = row_groups ,
821+ buffer_manager = buffer_mgr ,
822+ on_finalize_row_group = on_finalize ,
823+ shutdown_error_rate = 0.5 ,
824+ shutdown_error_window = 4 ,
825+ )
826+ await scheduler .run ()
827+
828+ assert scheduler .early_shutdown
829+ # The row group survived with the 3 fast successes; the in-flight rows were
830+ # cancelled and dropped by _finalize_after_shutdown.
831+ assert 0 in finalized
832+ assert scheduler .partial_row_groups == (0 ,)
833+ # Exactly 3 rows survived (seeds 0, 1, 2).
834+ assert buffer_mgr .actual_num_records == 3
835+
836+
837+ @pytest .mark .asyncio (loop_scope = "session" )
838+ async def test_zero_survivor_shutdown_does_not_raise () -> None :
839+ """If every row is dropped at shutdown, the row group is freed without writing parquet."""
840+ provider = _mock_provider ()
841+ configs = [
842+ SamplerColumnConfig (name = "seed" , sampler_type = SamplerType .CATEGORY , params = {"values" : ["A" ]}),
843+ LLMTextColumnConfig (name = "cell_out" , prompt = "{{ seed }}" , model_alias = MODEL_ALIAS ),
844+ ]
845+ strategies = {
846+ "seed" : GenerationStrategy .FULL_COLUMN ,
847+ "cell_out" : GenerationStrategy .CELL_BY_CELL ,
848+ }
849+ # All 5 seeds fail non-retryable → all rows dropped before any can complete.
850+ generators = {
851+ "seed" : MockSeedGenerator (config = _expr_config ("seed" ), resource_provider = provider ),
852+ "cell_out" : MockSelectiveFailGenerator (
853+ config = _expr_config ("cell_out" ),
854+ resource_provider = provider ,
855+ fail_on_seeds = set (range (5 )),
856+ ),
857+ }
858+
859+ graph = ExecutionGraph .create (configs , strategies )
860+ row_groups = [(0 , 5 )]
861+ tracker = CompletionTracker .with_graph (graph , row_groups )
862+
863+ storage = MagicMock ()
864+ storage .dataset_name = "test"
865+ storage .get_file_paths .return_value = {}
866+ storage .write_batch_to_parquet_file .return_value = "/fake.parquet"
867+ storage .move_partial_result_to_final_file_path .return_value = "/fake_final.parquet"
868+ buffer_mgr = RowGroupBufferManager (storage )
869+
870+ finalized : list [int ] = []
871+
872+ def on_finalize (rg_id : int ) -> None :
873+ buffer_mgr .checkpoint_row_group (rg_id )
874+ finalized .append (rg_id )
875+
876+ scheduler = AsyncTaskScheduler (
877+ generators = generators ,
878+ graph = graph ,
879+ tracker = tracker ,
880+ row_groups = row_groups ,
881+ buffer_manager = buffer_mgr ,
882+ on_finalize_row_group = on_finalize ,
883+ shutdown_error_rate = 0.5 ,
884+ shutdown_error_window = 2 ,
885+ )
886+ # Must not raise (no FileNotFoundError, no DataDesignerGenerationError).
887+ await scheduler .run ()
888+
889+ assert scheduler .early_shutdown
890+ assert buffer_mgr .actual_num_records == 0
891+ # All rows dropped → checkpoint path frees buffer without writing; on_finalize
892+ # is *not* called because every row was dropped before survivors could exist.
893+ assert finalized == []
894+ # No partial-row-groups recorded — there were no incomplete-but-not-dropped rows.
895+ assert scheduler .partial_row_groups == ()
896+ storage .write_batch_to_parquet_file .assert_not_called ()
897+
898+
899+ @pytest .mark .asyncio (loop_scope = "session" )
900+ async def test_healthy_run_has_no_partial_signal () -> None :
901+ """Successful run leaves early_shutdown=False and partial_row_groups empty."""
902+ scheduler , _tracker = _build_simple_pipeline (num_records = 3 )
903+ await scheduler .run ()
904+ assert not scheduler .early_shutdown
905+ assert scheduler .partial_row_groups == ()
725906
726907
727908@pytest .mark .asyncio (loop_scope = "session" )
0 commit comments