# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for fastdeploy.scheduler.dp_scheduler.

All external dependencies (fastdeploy modules, threading) are mocked BEFORE
the module under test is loaded, so no real threads, queues, or fastdeploy
imports are needed. The module is loaded directly from its file path via
importlib so that the mocks in sys.modules take effect.
"""

import sys
import time
import unittest
from multiprocessing import Queue
from unittest.mock import MagicMock, Mock, patch

# Mock all external dependencies before importing anything
mock_logger = Mock()


# Create a proper mock for FD_EP_BATCHED_TOKEN_TIMEOUT that can be compared with float
class MockEnv:
    FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1


mock_envs = MockEnv()

# Mock threading module to prevent real thread creation.
# NOTE: the local name `threading` below still binds the REAL module (it was
# imported before sys.modules was replaced), so MockLocalScheduler gets real
# locks/conditions; only modules imported afterwards see the mock.
import threading

mock_threading = Mock()
sys.modules["threading"] = mock_threading
mock_threading.Thread = Mock()
# MagicMock (not Mock) so the returned lock/condition objects support the
# context-manager protocol (`with lock:`) without needing to be patched.
mock_threading.Lock = Mock(return_value=MagicMock())
mock_threading.Condition = Mock(return_value=MagicMock())

# Create mock modules
sys.modules["fastdeploy"] = Mock()
sys.modules["fastdeploy.utils"] = Mock()
sys.modules["fastdeploy.envs"] = mock_envs
sys.modules["fastdeploy.engine"] = Mock()
sys.modules["fastdeploy.engine.request"] = Mock()
sys.modules["fastdeploy.scheduler"] = Mock()
sys.modules["fastdeploy.scheduler.local_scheduler"] = Mock()
sys.modules["fastdeploy.scheduler.data"] = Mock()

# Mock the get_logger function
sys.modules["fastdeploy.utils"].get_logger = Mock(return_value=mock_logger)


# Mock the Request, RequestOutput, and ScheduledResponse classes
class MockRequest:
    def __init__(self, request_id, prompt_tokens_ids_len=10):
        self.request_id = request_id
        self.prompt_tokens_ids_len = prompt_tokens_ids_len
        self.schedule_time = time.time()
        self.raw = self


class MockRequestOutput:
    def __init__(self, request_id, finished=False):
        self.request_id = request_id
        self.finished = finished


class MockScheduledResponse:
    def __init__(self, request_output):
        self.request_id = request_output.request_id
        self.finished = request_output.finished


# Mock LocalScheduler base class
class MockLocalScheduler:
    def __init__(
        self,
        max_size,
        ttl,
        enable_chunked_prefill,
        max_num_partial_prefills,
        max_long_partial_prefills,
        long_prefill_token_threshold,
    ):
        self.max_size = max_size
        self.ttl = ttl
        self.mutex = threading.Lock()
        self.requests = {}
        self.responses = {}
        self.ids = []
        self.ids_read_cursor = 0
        self.requests_not_empty = threading.Condition()
        self.responses_not_empty = threading.Condition()

    def calc_required_blocks(self, token_len, block_size):
        # Ceiling division: number of blocks needed to hold token_len tokens.
        return (token_len + block_size - 1) // block_size

    def put_requests(self, requests):
        with self.mutex:
            for request in requests:
                if request.request_id not in self.requests:
                    self.requests[request.request_id] = request
                    self.ids.append(request.request_id)
        with self.requests_not_empty:
            self.requests_not_empty.notify_all()

    def get_results(self):
        with self.responses_not_empty:
            # Don't actually wait, just check if there are responses
            if any(self.responses.values()):
                results = []
                for response_list in list(self.responses.values()):
                    results.extend(response_list)
                self.responses.clear()
                return results
            return []

    def _recycle(self, request_id=None):
        """Mock implementation of _recycle method."""
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            # In decode role the ids list is left untouched.
            if hasattr(self, "splitwise_role") and self.splitwise_role == "decode":
                return
            if request_id in self.ids:
                self.ids.pop(self.ids.index(request_id))
                self.ids_read_cursor = max(0, self.ids_read_cursor - 1)
            return

        if self.max_size <= 0:
            return

        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for req_id in self.ids:
            if req_id in self.requests:
                request = self.requests[req_id]
                if now - request.schedule_time >= self.ttl:
                    expired_ids.append(req_id)
                else:
                    # ids is in insertion order, so the first fresh request
                    # means everything after it is fresh too.
                    break

        for expired_id in expired_ids:
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)
            if expired_id in self.ids:
                self.ids.pop(self.ids.index(expired_id))

        if len(expired_ids) > 0:
            self.ids_read_cursor = max(0, self.ids_read_cursor - len(expired_ids))


# Set up the mock classes in the modules
sys.modules["fastdeploy.engine.request"].Request = MockRequest
sys.modules["fastdeploy.engine.request"].RequestOutput = MockRequestOutput
sys.modules["fastdeploy.scheduler.data"].ScheduledResponse = MockScheduledResponse
sys.modules["fastdeploy.scheduler.local_scheduler"].LocalScheduler = MockLocalScheduler

# Now we can import the dp_scheduler module with all dependencies mocked
import importlib.util
import os

spec = importlib.util.spec_from_file_location(
    "dp_scheduler", os.path.join(os.path.dirname(__file__), "../../fastdeploy/scheduler/dp_scheduler.py")
)
dp_scheduler_module = importlib.util.module_from_spec(spec)

# Mock the dependencies in the module
dp_scheduler_module.envs = mock_envs
dp_scheduler_module.get_logger = Mock(return_value=mock_logger)
dp_scheduler_module.threading = mock_threading  # Add threading to the module

# Execute the module
spec.loader.exec_module(dp_scheduler_module)

# Extract the classes we want to test
DPLocalScheduler = dp_scheduler_module.DPLocalScheduler
DPScheduler = dp_scheduler_module.DPScheduler

# Override the scheduler_logger to use our mock
original_init = DPLocalScheduler.__init__


def patched_init(self, *args, **kwargs):
    original_init(self, *args, **kwargs)
    self.scheduler_logger = mock_logger


DPLocalScheduler.__init__ = patched_init


class TestDPLocalScheduler(unittest.TestCase):
    """Test cases for DPLocalScheduler class."""

    def setUp(self):
        """Set up test fixtures."""
        self.scheduler = DPLocalScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="prefill",
        )

    def test_initialization_with_default_role(self):
        """Test scheduler initialization with default splitwise_role."""
        scheduler = DPLocalScheduler(
            max_size=50,
            ttl=30,
            enable_chunked_prefill=False,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
        )
        self.assertEqual(scheduler.splitwise_role, "prefill")
        self.assertEqual(scheduler.max_size, 50)
        self.assertEqual(scheduler.ttl, 30)

    def test_initialization_with_custom_role(self):
        """Test scheduler initialization with custom splitwise_role."""
        scheduler = DPLocalScheduler(
            max_size=50,
            ttl=30,
            enable_chunked_prefill=False,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
            splitwise_role="decode",
        )
        self.assertEqual(scheduler.splitwise_role, "decode")

    def test_put_results_with_finished_requests(self):
        """Test putting results with finished requests."""
        # Reset mock logger
        mock_logger.reset_mock()

        # Create mock request outputs
        results = [
            MockRequestOutput("req1", finished=True),
            MockRequestOutput("req2", finished=False),
            MockRequestOutput("req3", finished=True),
        ]

        # Put results - this should work without threading issues since we're using the real implementation
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results)

        # Check that finished requests were logged - the logger should have been called
        self.assertTrue(mock_logger.info.called)
        # Get the actual call arguments to verify the message format
        call_args = mock_logger.info.call_args[0][0]
        self.assertIn("finished responses", call_args)
        self.assertIn("req1", call_args)
        self.assertIn("req3", call_args)

    def test_put_results_with_new_responses(self):
        """Test putting results with new responses."""
        results = [MockRequestOutput("new_req", finished=False)]

        # Initially no responses
        self.assertNotIn("new_req", self.scheduler.responses)

        # Put results - mock the condition variable to avoid threading issues
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results)

        # Check response was added
        self.assertIn("new_req", self.scheduler.responses)
        self.assertEqual(len(self.scheduler.responses["new_req"]), 1)

    def test_put_results_with_existing_responses(self):
        """Test putting results with existing responses."""
        results1 = [MockRequestOutput("existing_req", finished=False)]
        results2 = [MockRequestOutput("existing_req", finished=True)]

        # Put first set of results - mock the condition variable to avoid threading issues
        with patch.object(self.scheduler, "responses_not_empty"):
            self.scheduler.put_results(results1)
            self.assertEqual(len(self.scheduler.responses["existing_req"]), 1)

            # Put second set of results
            self.scheduler.put_results(results2)
            self.assertEqual(len(self.scheduler.responses["existing_req"]), 2)

    def test_recycle_specific_request_id(self):
        """Test recycling a specific request ID."""
        # Add some test data
        self.scheduler.requests["req1"] = MockRequest("req1")
        self.scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        self.scheduler.ids = ["req1", "req2"]
        self.scheduler.ids_read_cursor = 1

        # Recycle specific request
        self.scheduler._recycle("req1")

        # Verify request was removed
        self.assertNotIn("req1", self.scheduler.requests)
        self.assertNotIn("req1", self.scheduler.responses)
        self.assertEqual(self.scheduler.ids, ["req2"])
        self.assertEqual(self.scheduler.ids_read_cursor, 0)

    def test_recycle_specific_request_id_decode_role(self):
        """Test recycling a specific request ID in decode role."""
        scheduler = DPLocalScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="decode",
        )

        # Add some test data
        scheduler.requests["req1"] = MockRequest("req1")
        scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        scheduler.ids = ["req1", "req2"]
        scheduler.ids_read_cursor = 1

        # Recycle specific request (should not modify ids in decode role)
        scheduler._recycle("req1")

        # Verify request and response were removed but ids unchanged
        self.assertNotIn("req1", scheduler.requests)
        self.assertNotIn("req1", scheduler.responses)
        self.assertEqual(scheduler.ids, ["req1", "req2"])  # Should not change in decode role
        self.assertEqual(scheduler.ids_read_cursor, 1)  # Should not change in decode role

    def test_recycle_with_max_size_zero(self):
        """Test recycling when max_size is 0 (unlimited)."""
        scheduler = DPLocalScheduler(
            max_size=0,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
        )

        # Add test data
        scheduler.requests["req1"] = MockRequest("req1")
        scheduler.responses["req1"] = [MockScheduledResponse(MockRequestOutput("req1"))]
        scheduler.ids = ["req1"]

        # Should return early without recycling
        scheduler._recycle()

        # Data should remain unchanged
        self.assertIn("req1", scheduler.requests)
        self.assertIn("req1", scheduler.responses)

    def test_recycle_under_max_size(self):
        """Test recycling when under max_size limit."""
        # Add test data under limit
        self.scheduler.requests["req1"] = MockRequest("req1")
        self.scheduler.requests["req2"] = MockRequest("req2")
        self.scheduler.ids = ["req1", "req2"]

        # Should return early without recycling
        self.scheduler._recycle()

        # Data should remain unchanged
        self.assertIn("req1", self.scheduler.requests)
        self.assertIn("req2", self.scheduler.requests)

    @patch("time.time")
    def test_recycle_expired_requests(self, mock_time):
        """Test recycling expired requests."""
        # Create a scheduler with smaller max_size to trigger recycling
        scheduler = DPLocalScheduler(
            max_size=1,  # Set to 1 to trigger recycling when we have 2 requests
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
        )

        # Mock time to make requests appear expired
        mock_time.return_value = 100.0

        # Create expired request (schedule_time = 30.0, ttl = 60, so expired)
        expired_request = MockRequest("expired_req")
        expired_request.schedule_time = 30.0  # 70 seconds ago (beyond ttl=60)

        # Create non-expired request
        fresh_request = MockRequest("fresh_req")
        fresh_request.schedule_time = 80.0  # 20 seconds ago (within ttl=60)

        # Add test data
        scheduler.requests["expired_req"] = expired_request
        scheduler.requests["fresh_req"] = fresh_request
        scheduler.ids = ["expired_req", "fresh_req"]
        scheduler.ids_read_cursor = 2

        # Recycle expired requests
        scheduler._recycle()

        # Verify expired request was removed, fresh request remains
        self.assertNotIn("expired_req", scheduler.requests)
        self.assertIn("fresh_req", scheduler.requests)
        self.assertEqual(scheduler.ids, ["fresh_req"])
        self.assertEqual(scheduler.ids_read_cursor, 1)

    def test_get_requests_insufficient_resources(self):
        """Test getting requests when resources are insufficient."""
        mock_logger.reset_mock()

        # Test with insufficient blocks - mock the condition variable to avoid threading issues
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        self.assertEqual(requests, [])
        # The logger should have been called for insufficient resources
        self.assertTrue(mock_logger.debug.called)
        # Check the message contains expected content
        call_args = mock_logger.debug.call_args[0][0]
        self.assertIn("insufficient", call_args.lower())

    def test_get_requests_insufficient_batch(self):
        """Test getting requests when batch size is insufficient."""
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
            )

        self.assertEqual(requests, [])

    @patch("time.time")
    @patch.object(dp_scheduler_module, "envs")
    def test_get_requests_no_requests_available(self, mock_envs, mock_time):
        """Test getting requests when no requests are available."""
        # Mock envs to return our mock environment
        mock_envs.FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1

        # Mock time to return consistent values - provide enough values for multiple calls
        time_values = [100.0, 100.1, 100.2, 100.3, 100.4, 100.5]  # Multiple values for the loop
        mock_time.side_effect = time_values

        # Mock the condition variable to avoid threading issues
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        # Should return empty list after timeout
        self.assertEqual(requests, [])

    def test_get_requests_successful_batching(self):
        """Test successful request batching."""
        # Add a mock request
        mock_request = MockRequest("test_req", prompt_tokens_ids_len=10)
        self.scheduler.requests["test_req"] = mock_request
        self.scheduler.ids = ["test_req"]

        # Mock calc_required_blocks to return small value
        self.scheduler.calc_required_blocks = Mock(return_value=1)

        requests = self.scheduler.get_requests(
            available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
        )

        # Should get the request
        self.assertEqual(len(requests), 1)
        self.assertEqual(requests[0].request_id, "test_req")

    @patch("time.time")
    @patch.object(dp_scheduler_module, "envs")
    def test_get_requests_timeout(self, mock_envs, mock_time):
        """Test request batching with timeout."""
        # Mock envs to return our mock environment
        mock_envs.FD_EP_BATCHED_TOKEN_TIMEOUT = 0.1

        # Mock time to return consistent values - provide enough values for multiple calls
        time_values = [100.0, 100.1, 100.2, 100.3, 100.4, 100.5]  # Multiple values for the loop
        mock_time.side_effect = time_values

        # Add a mock request
        mock_request = MockRequest("test_req", prompt_tokens_ids_len=10)
        self.scheduler.requests["test_req"] = mock_request
        self.scheduler.ids = ["test_req"]

        # Mock calc_required_blocks to return large value to exceed available blocks
        self.scheduler.calc_required_blocks = Mock(return_value=50)

        # Mock the condition variable to avoid threading issues
        with patch.object(self.scheduler, "requests_not_empty"):
            requests = self.scheduler.get_requests(
                available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
            )

        # Should return empty due to timeout
        self.assertEqual(requests, [])


class TestDPScheduler(unittest.TestCase):
    """Test cases for DPScheduler class."""

    def setUp(self):
        """Set up test fixtures."""
        self.dp_scheduler = DPScheduler(
            max_size=100,
            ttl=60,
            enable_chunked_prefill=True,
            max_num_partial_prefills=4,
            max_long_partial_prefills=2,
            long_prefill_token_threshold=1024,
            splitwise_role="prefill",
        )

    def test_initialization(self):
        """Test DPScheduler initialization."""
        self.assertIsNotNone(self.dp_scheduler._scheduler)
        self.assertEqual(self.dp_scheduler._scheduler.splitwise_role, "prefill")

    def test_get_unhandled_request_num(self):
        """Test getting number of unhandled requests."""
        # Initially should be 0
        self.assertEqual(self.dp_scheduler.get_unhandled_request_num(), 0)

        # Add a request to the internal scheduler
        mock_request = MockRequest("test_req")
        self.dp_scheduler._scheduler.requests["test_req"] = mock_request

        # Should return 1
        self.assertEqual(self.dp_scheduler.get_unhandled_request_num(), 1)

    def test_put_results(self):
        """Test putting results to DPScheduler."""
        results = [MockRequestOutput("test_req", finished=True)]

        # Should not raise an exception - mock the condition variable to avoid threading issues
        with patch.object(self.dp_scheduler._scheduler, "responses_not_empty"):
            self.dp_scheduler.put_results(results)

        # Verify results were added to the internal scheduler
        self.assertIn("test_req", self.dp_scheduler._scheduler.responses)

    def test_get_requests_delegates_to_scheduler(self):
        """Test that get_requests delegates to internal scheduler."""
        # Mock the internal scheduler's get_requests method
        expected_requests = [MockRequest("test_req")]
        self.dp_scheduler._scheduler.get_requests = Mock(return_value=expected_requests)

        requests = self.dp_scheduler.get_requests(
            available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
        )

        # Verify delegation
        self.dp_scheduler._scheduler.get_requests.assert_called_once_with(20, 16, 10, 1024, 1)
        self.assertEqual(requests, expected_requests)

    def test_put_requests_missing_dp_rank(self):
        """Test put_requests raises error when dp_rank is missing."""
        # Create a request without dp_rank attribute
        mock_request = MockRequest("test_req")

        requests = [mock_request]

        # Should raise ValueError
        with self.assertRaises(ValueError) as cm:
            self.dp_scheduler.put_requests(requests)

        self.assertIn("missing the 'dp_rank' attribute", str(cm.exception))

    @patch("threading.Thread")
    def test_put_requests_success(self, mock_thread):
        """Test successful put_requests with dp_rank."""
        # Create request queues - use Mock instead of real Queue to avoid threading issues
        request_queues = [Mock(), Mock(), Mock()]
        result_queue = Mock()

        # Start the scheduler - this will create mocked threads
        self.dp_scheduler.start(0, request_queues, result_queue)

        # Create requests with dp_rank
        mock_request1 = MockRequest("test_req1")
        mock_request1.dp_rank = 0
        mock_request2 = MockRequest("test_req2")
        mock_request2.dp_rank = 1

        requests = [mock_request1, mock_request2]

        # Should not raise an exception
        results = self.dp_scheduler.put_requests(requests)

        # Verify results format
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0], ("test_req1", None))
        self.assertEqual(results[1], ("test_req2", None))

        # Verify requests were put to the correct queues
        request_queues[0].put.assert_called_once_with(mock_request1)
        request_queues[1].put.assert_called_once_with(mock_request2)

    @patch("threading.Thread")
    def test_start_creates_threads(self, mock_thread):
        """Test that start creates and starts threads."""
        mock_thread.return_value = Mock()

        request_queues = [Queue(), Queue()]
        result_queue = Queue()

        self.dp_scheduler.start(0, request_queues, result_queue)

        # Should create 2 threads
        self.assertEqual(mock_thread.call_count, 2)

        # Both threads should be started
        mock_thread.return_value.start.assert_called()


class TestDPIntegration(unittest.TestCase):
    """Integration tests for DP Scheduler functionality."""

    def test_end_to_end_request_flow(self):
        """Test end-to-end request flow through DP scheduler - without real threads."""
        # Create DP scheduler
        dp_scheduler = DPScheduler(
            max_size=10,
            ttl=30,
            enable_chunked_prefill=True,
            max_num_partial_prefills=2,
            max_long_partial_prefills=1,
            long_prefill_token_threshold=512,
        )

        # Mock the start method to avoid creating real threads
        with patch.object(dp_scheduler, "start") as mock_start:
            # Set up test data directly
            dp_scheduler.dp_rank = 0
            dp_scheduler.request_queues = [Mock(), Mock()]
            dp_scheduler.result_queue = Mock()
            dp_scheduler.scheduler_logger = mock_logger
            dp_scheduler._scheduler.scheduler_logger = mock_logger

            # Test basic functionality without real threads
            mock_request = MockRequest("integration_req")
            mock_request.dp_rank = 0

            # Mock the request_queues to avoid real Queue operations
            dp_scheduler.request_queues[0].put = Mock()

            # Test put_requests functionality
            results = dp_scheduler.put_requests([mock_request])
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0], ("integration_req", None))

            # Verify the request was put to the correct queue
            dp_scheduler.request_queues[0].put.assert_called_once_with(mock_request)

            # Verify start method was not called (to avoid threads)
            mock_start.assert_not_called()


if __name__ == "__main__":
    unittest.main(verbosity=2)