@@ -919,6 +919,119 @@ async def test_retryable_errors_do_not_trigger_early_shutdown(
919919 assert tracker .is_row_group_complete (0 , 10 , ["seed" , "col" ])
920920
921921
@pytest.mark.asyncio(loop_scope="session")
async def test_degraded_provider_warn_fires_above_threshold(caplog: pytest.LogCaptureFixture) -> None:
    """A WARN log is emitted once retryable errors dominate the recent-outcome window."""
    provider = _mock_provider()

    column_configs = [
        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
        LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
    ]
    column_strategies = {
        "seed": GenerationStrategy.FULL_COLUMN,
        "col": GenerationStrategy.CELL_BY_CELL,
    }
    # 6 retryable failures across 10 cells + their successful retries → ~6/16 retryable.
    # Set window to 8 and threshold to 0.5 so the WARN can fire.
    column_generators = {
        "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
        "col": MockRetryableErrorGenerator(
            config=_expr_config("col"),
            resource_provider=provider,
            error_factory=lambda: ModelTimeoutError("read timeout"),
            retryable_failures=6,
        ),
    }

    execution_graph = ExecutionGraph.create(column_configs, column_strategies)
    groups = [(0, 10)]
    completion = CompletionTracker.with_graph(execution_graph, groups)

    # Storage is mocked out entirely; only the scheduler's logging path is under test.
    fake_storage = MagicMock()
    fake_storage.dataset_name = "test"
    fake_storage.get_file_paths.return_value = {}
    fake_storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
    fake_storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"

    scheduler = AsyncTaskScheduler(
        generators=column_generators,
        graph=execution_graph,
        tracker=completion,
        row_groups=groups,
        buffer_manager=RowGroupBufferManager(fake_storage),
        degraded_warn_rate=0.5,
        degraded_warn_window=8,
        degraded_warn_interval_s=0.0,
    )

    with caplog.at_level("WARNING"):
        await scheduler.run()

    warn_records = [r for r in caplog.records if "degraded performance" in r.getMessage()]
    assert warn_records, "expected a 'degraded performance' WARN to be emitted"
972+
973+
@pytest.mark.asyncio(loop_scope="session")
async def test_degraded_provider_warn_throttled(caplog: pytest.LogCaptureFixture) -> None:
    """Back-to-back degraded windows inside the throttle interval produce a single WARN."""
    provider = _mock_provider()

    column_configs = [
        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
        LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
    ]
    column_strategies = {
        "seed": GenerationStrategy.FULL_COLUMN,
        "col": GenerationStrategy.CELL_BY_CELL,
    }
    # 8 retryable failures over 12 rows with a small window (4) keeps the
    # degraded condition true across multiple windows during the run.
    column_generators = {
        "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
        "col": MockRetryableErrorGenerator(
            config=_expr_config("col"),
            resource_provider=provider,
            error_factory=lambda: ModelTimeoutError("read timeout"),
            retryable_failures=8,
        ),
    }

    execution_graph = ExecutionGraph.create(column_configs, column_strategies)
    groups = [(0, 12)]
    completion = CompletionTracker.with_graph(execution_graph, groups)

    # Storage is a stub; the assertion is purely about log throttling.
    fake_storage = MagicMock()
    fake_storage.dataset_name = "test"
    fake_storage.get_file_paths.return_value = {}
    fake_storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
    fake_storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"

    scheduler = AsyncTaskScheduler(
        generators=column_generators,
        graph=execution_graph,
        tracker=completion,
        row_groups=groups,
        buffer_manager=RowGroupBufferManager(fake_storage),
        degraded_warn_rate=0.5,
        degraded_warn_window=4,
        # A one-hour throttle guarantees at most one WARN within this run.
        degraded_warn_interval_s=3600.0,
    )

    with caplog.at_level("WARNING"):
        await scheduler.run()

    warn_records = [r for r in caplog.records if "degraded performance" in r.getMessage()]
    assert len(warn_records) == 1, f"expected exactly one throttled WARN, got {len(warn_records)}"
1022+
1023+
@pytest.mark.asyncio(loop_scope="session")
async def test_degraded_provider_warn_silent_under_threshold(caplog: pytest.LogCaptureFixture) -> None:
    """The degraded-provider WARN stays silent when a run encounters no errors at all."""
    scheduler, _ = _build_simple_pipeline(num_records=5)

    with caplog.at_level("WARNING"):
        await scheduler.run()

    # A healthy run must never log the degraded-performance warning.
    assert not [r for r in caplog.records if "degraded performance" in r.getMessage()]
1033+
1034+
9221035@pytest .mark .asyncio (loop_scope = "session" )
9231036async def test_scheduler_on_before_checkpoint_callback () -> None :
9241037 """on_before_checkpoint is called before each row group is checkpointed."""
0 commit comments