[CI][Hackathon 9th Sprint No.12] Add unit tests for fastdeploy/spec_decode/mtp.py #5533
CSWYF3634076 merged 10 commits into PaddlePaddle:develop
Conversation
Thanks for your contribution!
Pull request overview
This PR adds comprehensive unit tests for the MTP (Multi-Token-Prediction) proposer module (fastdeploy/spec_decode/mtp.py) as part of Hackathon 9th Sprint. The tests use extensive mocking to verify various initialization, configuration, and runtime behaviors of the MTPProposer class.
- Adds 498 lines of unit tests covering MTPProposer initialization, configuration updates, KV cache management, and various runtime operations
- Tests multiple scenarios including different request types (PREFILL, DECODE, PREEMPTED), chunked prefill, expert parallel, and multimodal inputs
- Uses unittest.mock extensively to isolate the component under test
@@ -0,0 +1,526 @@
The PR title format is correct with the [CI] tag, but it should be translated to English to match standard GitHub PR practices. The Chinese title should be translated to something like: "[CI][Hackathon 9th Sprint No.12] Add unit tests for fastdeploy/spec_decode/mtp.py"
The test creates a Request object with numerous parameters but doesn't validate that the MTPProposer correctly processes all these fields. Consider adding assertions to verify that the proposer correctly handles request.idx, request.block_tables, request.draft_token_ids, etc., after insertion.
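A minimal sketch of the kind of post-insert assertions the review is asking for. `insert_prefill` is a stand-in modeled on what `insert_tasks_v1` is expected to write for a PREFILL request (numpy replaces paddle so the sketch runs anywhere); the buffer names mirror the test fixtures above.

```python
import numpy as np

# Stand-in for proposer.model_inputs before inserting request1 (idx=0, 5 prompt tokens)
model_inputs = {
    "seq_lens_this_time": np.zeros([2, 1], dtype="int32"),
    "block_tables": np.full([2, 10], -1, dtype="int32"),
}

def insert_prefill(model_inputs, idx, prompt_len, block_tables):
    # Mirrors the per-slot writes insert_tasks_v1 is expected to perform
    model_inputs["seq_lens_this_time"][idx, 0] = prompt_len
    model_inputs["block_tables"][idx, : len(block_tables)] = block_tables

insert_prefill(model_inputs, idx=0, prompt_len=5, block_tables=[0, 1])

# The assertions the review asks for, mirroring request1's fields
assert model_inputs["seq_lens_this_time"][0, 0] == 5
assert model_inputs["block_tables"][0, :2].tolist() == [0, 1]
```

In the real test these would become `self.assertEqual` calls against `proposer.model_inputs` after `proposer.insert_tasks_v1([request1], 1)`.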
```python
def test_extend_draft_token_and_run_impl(self, mock_ngram, mock_rope, mock_attn_backend, mock_model_loader):
    """Test _extend_draft_token_with_ngram_match and _run_impl"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
```
The test uses patch.object to mock internal methods but doesn't verify these mocked methods were called with the correct arguments. Consider adding assertions like mock_method.assert_called_once_with(expected_args) to ensure the implementation flow is correct.
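The pattern the review suggests, shown on a plain `Mock` so it runs standalone: `assert_called_once_with` pins down both the call count and the arguments, which a bare `patch.object` context manager does not.

```python
from unittest.mock import Mock, call

proposer = Mock()
# The flow under test: _run_impl should invoke the hook exactly once
proposer._update_status(accept_num=2)

# Verifies one call with exactly these arguments; raises AssertionError otherwise
proposer._update_status.assert_called_once_with(accept_num=2)
assert proposer._update_status.call_args == call(accept_num=2)
```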
```python
    patch.object(proposer, "_update_status"),
):
    proposer._run_impl(full_hidden_states)
```
The test for _empty_cache only verifies that paddle.device.cuda.empty_cache is called but doesn't test the actual cache clearing behavior or side effects. Consider verifying that the method actually affects the cache state or memory usage in a meaningful way.
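A sketch of a behavior-level check along those lines: assert the observable state change (cache entries dropped) in addition to the allocator hook being called. `clear_mtp_cache` here is a stand-in for the method under test, not the real implementation.

```python
from unittest.mock import Mock

def clear_mtp_cache(model_inputs, empty_cache_fn):
    # Stand-in: drop the cache tensors, then release allocator blocks
    model_inputs.pop("caches", None)
    empty_cache_fn()

model_inputs = {"caches": [object(), object()], "input_ids": [1, 2, 3]}
hook = Mock()  # stands in for paddle.device.cuda.empty_cache
clear_mtp_cache(model_inputs, hook)

assert "caches" not in model_inputs            # observable state change
assert model_inputs["input_ids"] == [1, 2, 3]  # unrelated inputs untouched
hook.assert_called_once()                      # allocator hook still fired
```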
```python
@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_init_and_config_methods(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test initialization and config update methods"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn_backend.return_value.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )

    # Test _update_mtp_config
    self.assertEqual(proposer.model_config.architectures[0], "ErnieMTPForCausalLM")
    self.assertEqual(proposer.model_config.num_hidden_layers, 1)
    self.assertEqual(proposer.speculative_config.model_type, "mtp")

    # Test _get_cache_type
    cache_type = proposer._get_cache_type()
    self.assertIn(cache_type, ["uint8", "int8"])

    # Test is_chunk_prefill_enabled
    self.assertTrue(proposer.is_chunk_prefill_enabled())

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_dummy_prefill_inputs_and_kv_cache(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test dummy_prefill_inputs and initialize_kv_cache with different branches"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )

    # Test dummy_prefill_inputs with expert parallel
    self.fd_config.parallel_config.enable_expert_parallel = True
    proposer.dummy_prefill_inputs(num_tokens=100, batch_size=2, expected_decode_len=10)
    self.assertGreater(proposer.model_inputs["seq_lens_encoder"][0].item(), 0)

    # Test initialize_kv_cache with prefix caching
    self.fd_config.cache_config.enable_prefix_caching = True
    proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)
    self.assertIn("caches", proposer.model_inputs)

    # Test initialize_kv_cache with block_wise_fp8
    self.fd_config.quant_config = QuantizationConfig({})
    self.fd_config.quant_config.kv_cache_quant_type = "block_wise_fp8"
    proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)

    # Test initialize_kv_cache with profile=True
    proposer.initialize_kv_cache(main_model_num_blocks=10, profile=True)

    # Test clear_mtp_cache
    proposer.clear_mtp_cache()
    self.assertNotIn("caches", proposer.model_inputs)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_update_mtp_block_num(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test update_mtp_block_num"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )
    proposer.update_mtp_block_num(num_gpu_blocks=20)
    self.assertEqual(proposer.main_model_num_gpu_blocks, 20)
    self.assertIn("free_list", proposer.model_inputs)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test insert_tasks_v1 with different request types"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )

    # Test with PREFILL request
    request1 = Request(
        request_id="test1",
        prompt="test",
        prompt_token_ids=[1, 2, 3, 4, 5],
        prompt_token_ids_len=5,
        messages=None,
        history=None,
        tools=None,
        system=None,
        eos_token_ids=[2],
        arrival_time=0.0,
    )
    request1.idx = 0
    request1.task_type = RequestType.PREFILL
    request1.prefill_start_index = 0
    request1.prefill_end_index = 5
    request1.output_token_ids = []
    request1.block_tables = [0, 1]

    # Test with DECODE request
    request2 = Request(
        request_id="test2",
        prompt="test",
        prompt_token_ids=[1, 2],
        prompt_token_ids_len=2,
        messages=None,
        history=None,
        tools=None,
        system=None,
        eos_token_ids=[2],
        arrival_time=0.0,
    )
    request2.idx = 1
    request2.task_type = RequestType.DECODE
    request2.block_tables = [2, 3]

    # Test with PREEMPTED request
    request3 = Request(
        request_id="test3",
        prompt="test",
        prompt_token_ids=[1],
        prompt_token_ids_len=1,
        messages=None,
        history=None,
        tools=None,
        system=None,
        eos_token_ids=[2],
        arrival_time=0.0,
    )
    request3.idx = 0
    request3.task_type = RequestType.PREEMPTED

    # Test splitwise_role == "decode"
    self.fd_config.scheduler_config.splitwise_role = "decode"
    proposer.insert_tasks_v1([request1], 1)

    # Test with multimodal
    proposer.enable_mm = True
    request1.multimodal_inputs = {"attention_mask_offset": [0, 1, 2, 3, 4]}
    proposer.model_inputs["attn_mask_offsets_full"] = paddle.zeros([2, 2048], dtype="int32")
    proposer.model_inputs["attn_mask_offsets_decoder"] = paddle.zeros([2, 1], dtype="int32")
    proposer.insert_tasks_v1([request1], 1)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_insert_prefill_inputs(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test insert_prefill_inputs with different roles and chunked prefill"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )

    request = Request(
        request_id="test",
        prompt="test",
        prompt_token_ids=[1, 2, 3, 4, 5],
        prompt_token_ids_len=5,
        messages=None,
        history=None,
        tools=None,
        system=None,
        eos_token_ids=[2],
        arrival_time=0.0,
    )
    request.idx = 0
    request.block_tables = [0, 1]
    request.draft_token_ids = [10, 11]

    # Test with prefill role
    request.disaggregate_info = {"role": "prefill"}
    proposer.insert_prefill_inputs([request], 1)
    self.assertEqual(proposer.role, "prefill")

    # Test with decode role
    request.disaggregate_info = {"role": "decode"}
    proposer.insert_prefill_inputs([request], 1)
    self.assertEqual(proposer.role, "decode")

    # Test with chunked prefill
    self.fd_config.cache_config.enable_chunked_prefill = True
    request.prefill_chunk_info = [3, 2]
    request.disaggregate_info = None
    proposer.insert_prefill_inputs([request], 1)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_forward_meta_and_exist_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test _initialize_forward_meta, _initialize_forward_meta_xpu, and exist_prefill"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )
    proposer.initialize_kv_cache(main_model_num_blocks=10)

    # Test _initialize_forward_meta
    proposer._initialize_forward_meta(step_use_cudagraph=False)
    self.assertIsNotNone(proposer.forward_meta)

    # Test _initialize_forward_meta_xpu
    proposer._initialize_forward_meta_xpu()
    self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")

    # Test exist_prefill
    proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
    result = proposer.exist_prefill()
    self.assertEqual(result, 1)

    proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
    result = proposer.exist_prefill()
    self.assertEqual(result, 0)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
@patch("fastdeploy.spec_decode.mtp.draft_model_preprocess")
@patch("fastdeploy.spec_decode.mtp.eagle_get_hidden_states")
def test_prepare_inputs_and_post_process(
    self, mock_eagle, mock_preprocess, mock_rope, mock_attn_backend, mock_model_loader
):
    """Test _prepare_inputs and _post_process"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])
    mock_eagle.return_value = paddle.zeros([2, 768], dtype="bfloat16")
    mock_preprocess.return_value = None

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )
    full_hidden_states = paddle.zeros([2, 768], dtype="bfloat16")

    # Test _prepare_inputs
    proposer._prepare_inputs(full_hidden_states)
    mock_preprocess.assert_called()
    mock_eagle.assert_called()

    # Test _post_process with prefill role
    proposer.role = "prefill"
    sampled_token_ids = paddle.ones([2, 1], dtype="int64")
    proposer._post_process(sampled_token_ids)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_update_task_chunk_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test update_task_chunk_prefill"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = paddle.zeros([1, 2048, 64])

    proposer = MTPProposer(
        self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
    )

    task = Mock()
    task.idx = 0
    task.prefill_chunk_info = [3, 2, 1]
    task.prompt_token_ids = [1, 2, 3, 4, 5, 6]

    # Test chunk_idx == len(prefill_chunk_info)
    task.chunk_idx = 3
    task.get = Mock(return_value=0)
    proposer.update_task_chunk_prefill(task)

    # Test chunk_idx < len - 1
    task.chunk_idx = 0
    proposer.update_task_chunk_prefill(task)

    # Test last prefill
    task.chunk_idx = 2
    proposer.update_task_chunk_prefill(task)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
@patch("fastdeploy.spec_decode.mtp.draft_model_postprocess")
@patch("fastdeploy.spec_decode.mtp.mtp_step_paddle")
def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader):
    """Test _update_status"""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_model_loader.return_value.load_model.return_value = mock_model
```
The test repeatedly patches the same three decorators (get_model_loader, get_attention_backend, get_rope) across multiple test methods. This is a code smell indicating duplication. Consider using a test class-level patcher or a setUp helper method that applies these patches once for all tests, improving maintainability and reducing boilerplate.
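One way to realize this suggestion is a base class whose `setUp` starts the shared patchers once. In this sketch `os.getenv` is a stand-in target so the example runs anywhere; the real suite would patch the three `fastdeploy.spec_decode.mtp.*` symbols instead.

```python
import unittest
from unittest.mock import patch

class MTPProposerTestBase(unittest.TestCase):
    """Starts shared patchers per test; subclasses get the mocks for free."""

    def setUp(self):
        # In the real suite: patch("fastdeploy.spec_decode.mtp.get_model_loader"), etc.
        self.mock_loader = patch("os.getenv", return_value="stub").start()
        self.addCleanup(patch.stopall)  # undo all patches after each test

class ExampleTest(MTPProposerTestBase):
    def test_shared_mock_is_active(self):
        import os
        self.assertEqual(os.getenv("ANY_KEY"), "stub")

result = unittest.TestResult()
unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest).run(result)
assert result.wasSuccessful()
```

`addCleanup(patch.stopall)` guarantees teardown even when a test fails, which a decorator stack also gives you but without the per-method repetition.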
```python
self.target_model_inputs = {
    "block_tables": paddle.zeros([2, 10], dtype="int32"),
    "input_ids": paddle.zeros([2, 2048], dtype="int64"),
    "seq_lens_this_time": paddle.zeros([2, 1], dtype="int32"),
    "seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
    "seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
    "prompt_lens": paddle.zeros([2, 1], dtype="int64"),
    "step_idx": paddle.zeros([2, 1], dtype="int64"),
    "stop_flags": paddle.zeros([2, 1], dtype="bool"),
    "stop_nums": paddle.zeros([2, 1], dtype="int32"),
    "pre_ids": paddle.zeros([2, 2048], dtype="int64"),
    "output_cum_offsets": paddle.zeros([2], dtype="int32"),
    "output_padding_offset": paddle.zeros([2], dtype="int32"),
    "ids_remove_padding": paddle.zeros([2], dtype="int64"),
    "batch_id_per_token": paddle.zeros([2], dtype="int32"),
    "cu_seqlens_q": paddle.zeros([3], dtype="int32"),
    "cu_seqlens_k": paddle.zeros([3], dtype="int32"),
    "decoder_batch_ids": paddle.zeros([2], dtype="int32"),
    "decoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
    "decoder_num_blocks_cpu": paddle.zeros([2], dtype="int32").cpu(),
    "decoder_num_blocks_device": paddle.zeros([2], dtype="int32"),
    "decoder_chunk_size_device": paddle.zeros([2], dtype="int32"),
    "max_len_tensor_cpu": paddle.zeros([2], dtype="int32").cpu(),
    "encoder_batch_ids": paddle.zeros([2], dtype="int32"),
    "encoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
    "encoder_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
    "kv_batch_ids": paddle.zeros([2], dtype="int32"),
    "kv_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
    "kv_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
    "top_p": paddle.ones([2, 1], dtype="float32") * 0.9,
    "top_k": paddle.zeros([2, 1], dtype="int32"),
    "temperature": paddle.ones([2, 1], dtype="float32"),
    "eos_token_id": paddle.ones([2], dtype="int64") * 2,
    "penalty_score": paddle.ones([2, 1], dtype="float32"),
    "frequency_score": paddle.zeros([2, 1], dtype="float32"),
    "presence_score": paddle.zeros([2, 1], dtype="float32"),
    "infer_seed": paddle.zeros([2, 1], dtype="int64"),
    "max_dec_len": paddle.ones([2, 1], dtype="int64") * 512,
    "min_dec_len": paddle.zeros([2, 1], dtype="int64"),
    "bad_tokens": paddle.zeros([2], dtype="int64"),
    "draft_tokens": paddle.zeros([2, 2], dtype="int64"),
    "accept_tokens": paddle.zeros([2, 2], dtype="int64"),
    "accept_num": paddle.ones([2], dtype="int32"),
    "draft_logits": paddle.zeros([4, 32000], dtype="float32"),
    "temp_scaled_logprobs": paddle.zeros([2], dtype="float32"),
    "top_p_normalized_logprobs": paddle.zeros([2], dtype="float32"),
    "encoder_block_lens": paddle.zeros([2, 1], dtype="int32"),
    "cu_batch_token_offset": paddle.zeros([3], dtype="int32"),
    "is_block_step": paddle.zeros([2], dtype="bool"),
}
```
The target_model_inputs dictionary is duplicated in the setUp method with extensive boilerplate. This large initialization block (lines 63-112) makes the setUp method hard to read and maintain. Consider extracting this into a helper method like _create_default_target_model_inputs() or moving it to the test utilities module for reuse across tests.
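A sketch of the suggested extraction. Only a representative subset of keys is shown, and `np.zeros` stands in for `paddle.zeros` so the sketch is paddle-free; the factory name is hypothetical.

```python
import numpy as np

def make_target_model_inputs(batch_size=2, max_model_len=2048, zeros=np.zeros):
    """Factory for the shared target_model_inputs fixture (subset of keys)."""
    return {
        "block_tables": zeros([batch_size, 10], dtype="int32"),
        "input_ids": zeros([batch_size, max_model_len], dtype="int64"),
        "seq_lens_this_time": zeros([batch_size, 1], dtype="int32"),
        "stop_flags": zeros([batch_size, 1], dtype="bool"),
    }

# setUp then shrinks to a single line:
target_model_inputs = make_target_model_inputs()
assert target_model_inputs["input_ids"].shape == (2, 2048)
```

Parameterizing `batch_size` and `max_model_len` also lets individual tests request variant fixtures without copy-pasting the dict.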
```python
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn_backend.return_value.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_rope.return_value = paddle.zeros([1, 2048, 64])
```
The mock return values use hardcoded shapes like [2, 12, 16, 64] and [2, 32000] without explaining their meaning. These should either use named constants or include comments explaining what each dimension represents (e.g., batch_size=2, num_layers=12, seq_len=16, head_dim=64).
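A sketch of the named-constant version. The meanings assigned to the middle dimensions are assumptions here; the true semantics come from the attention backend's `get_kv_cache_shape` contract.

```python
# Assumed meanings for the magic shapes used throughout the tests
BATCH_SIZE = 2
NUM_KV_HEADS = 12   # assumption; depends on get_kv_cache_shape's contract
BLOCK_SIZE = 16     # assumption
HEAD_DIM = 64
VOCAB_SIZE = 32000

KV_CACHE_SHAPE = [BATCH_SIZE, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM]
LOGITS_SHAPE = [BATCH_SIZE, VOCAB_SIZE]

# The mock setup then reads self-documentingly:
# mock_attn.get_kv_cache_shape.return_value = (KV_CACHE_SHAPE, KV_CACHE_SHAPE)
assert KV_CACHE_SHAPE == [2, 12, 16, 64]
assert LOGITS_SHAPE == [2, 32000]
```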
The test creates Request objects with complex initialization but doesn't test edge cases like empty prompt_token_ids, None values, or invalid indices. Consider adding negative test cases to verify the proposer handles invalid inputs gracefully.
```python
@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
    """Test insert_tasks_v1 with different request types"""
    mock_model = Mock()
```
The test uses the same mock setup repeatedly (mock_model with compute_logits, mock_attn with get_kv_cache_shape returning [2, 12, 16, 64], and mock_rope returning zeros). Consider extracting this into a helper method like _setup_common_mocks() to reduce duplication and improve test maintainability.
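A sketch of such a `_setup_common_mocks` helper. String placeholders stand in for the paddle tensors so the sketch runs without paddle; the helper name and signature are suggestions, not existing API.

```python
from unittest.mock import Mock

def setup_common_mocks(mock_rope, mock_attn_backend, mock_model_loader,
                       kv_shape=([2, 12, 16, 64], [2, 12, 16, 64])):
    """Configure the three decorator-injected mocks in one place."""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value="logits[2, 32000]")  # stand-in tensor
    mock_model_loader.return_value.load_model.return_value = mock_model
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = kv_shape
    mock_attn_backend.return_value = mock_attn
    mock_rope.return_value = "rope[1, 2048, 64]"  # stand-in tensor
    return mock_model

rope, attn, loader = Mock(), Mock(), Mock()
model = setup_common_mocks(rope, attn, loader)

assert loader.return_value.load_model.return_value is model
assert attn.return_value.get_kv_cache_shape.return_value == ([2, 12, 16, 64], [2, 12, 16, 64])
```

Each test body then opens with a single `self.setup_common_mocks(...)` call instead of the seven-line preamble.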
```python
# Test chunk_idx == len(prefill_chunk_info)
task.chunk_idx = 3
task.get = Mock(return_value=0)
proposer.update_task_chunk_prefill(task)

# Test chunk_idx < len - 1
task.chunk_idx = 0
proposer.update_task_chunk_prefill(task)

# Test last prefill
task.chunk_idx = 2
proposer.update_task_chunk_prefill(task)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
@patch("fastdeploy.spec_decode.mtp.draft_model_postprocess")
@patch("fastdeploy.spec_decode.mtp.mtp_step_paddle")
```
The test for _update_status only verifies that mtp_step_paddle is called but doesn't check the function was called with correct arguments or verify the state changes. Similarly, the test doesn't cover the ENABLE_V1_KVCACHE_SCHEDULER=True branch. Consider adding assertions for both branches and verifying the actual behavior.
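A sketch of testing both sides of an env-gated branch with argument checks. `update_status` is a stand-in dispatcher, not the real `_update_status`; only the flag name is taken from the diff above.

```python
from unittest.mock import Mock

def update_status(step_fn, enable_v1_kvcache_scheduler, accept_num):
    # Stand-in for the dispatch inside MTPProposer._update_status
    if enable_v1_kvcache_scheduler:
        step_fn(accept_num, scheduler="v1")
    else:
        step_fn(accept_num, scheduler="legacy")

step = Mock()  # stands in for mtp_step_paddle
update_status(step, True, accept_num=[1, 1])
step.assert_called_once_with([1, 1], scheduler="v1")

step.reset_mock()
update_status(step, False, accept_num=[1, 1])
step.assert_called_once_with([1, 1], scheduler="legacy")
```

In the real test, the flag would be flipped with `patch("fastdeploy.spec_decode.mtp.envs.ENABLE_V1_KVCACHE_SCHEDULER", True)` and the assertions made on the patched `mock_mtp_step`.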
Codecov Report

✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##           develop    #5533   +/- ##
==========================================
  Coverage         ?   61.31%
==========================================
  Files            ?      329
  Lines            ?    41158
  Branches         ?     6274
==========================================
  Hits             ?    25238
  Misses           ?    14030
  Partials         ?     1890

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
```python
)

# Test with PREFILL request
request1 = Request(
```
There are three requests here, but request1 and request3 have duplicate idx values, and only request1 is actually inserted. What is this meant to test?
I suggest inserting all three requests, then verifying that the values written into share_inputs are correct.
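A sketch of the hazard this comment points at: two requests sharing `idx=0` write to the same slot, so the last insert silently wins. The per-slot write is a stand-in for what `insert_tasks_v1` does; numpy replaces paddle so the sketch runs anywhere.

```python
import numpy as np

seq_lens_this_time = np.zeros([2, 1], dtype="int32")

def insert(idx, prompt_len):
    # Stand-in for the per-slot write insert_tasks_v1 performs
    seq_lens_this_time[idx, 0] = prompt_len

insert(0, 5)  # request1: PREFILL, 5 prompt tokens, idx 0
insert(1, 2)  # request2: DECODE, idx 1
insert(0, 1)  # request3 reuses idx 0 and overwrites request1

assert seq_lens_this_time[0, 0] == 1   # request1's entry is gone
assert seq_lens_this_time[1, 0] == 2   # request2's slot is intact
```

Inserting all three and asserting on each slot, as the comment suggests, would surface this overwrite instead of leaving it untested.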
```python
self.fd_config.cache_config.enable_chunked_prefill = True
request.prefill_chunk_info = [3, 2]
request.disaggregate_info = None
proposer.insert_prefill_inputs([request], 1)
```

```python
# Test _post_process with prefill role
proposer.role = "prefill"
sampled_token_ids = paddle.ones([2, 1], dtype="int64")
proposer._post_process(sampled_token_ids)
```
Same as above: this only executes the call; the resulting values should also be checked for correctness.
```python
proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

# Test with ENABLE_V1_KVCACHE_SCHEDULER=False
with patch("fastdeploy.spec_decode.mtp.envs.ENABLE_V1_KVCACHE_SCHEDULER", False):
```
Hi, the issues have been fixed, and CI has also finished running.

@freeliuzc Please take another look at whether the changes meet the requirements.

0.1⭐️
[CI][Hackathon 9th Sprint No.12] Add unit tests for fastdeploy/spec_decode/mtp.py (PaddlePaddle#5533)

* Add unit tests for MTPProposer class in spec_decode/mtp.py
* fix: remove non-existent QuantizationConfig import in test_mtp_proposer
* fix: add logprobs_mode attribute to FakeModelConfig
* fix: fix test failures in test_mtp_proposer - fix Mock setup, remove arrival_time, add missing keys
* fix: add seq_lens_this_time initialization and kv_cache init before insert_tasks_v1
* fix: check pos_emb_type attribute existence before assertion
* test: add minimal coverage for mtp cache type, mm init, preempted
* test: fix cache_type_branches unsupported platform on 12
* test: refine MTPProposer tests for cache type, requests and chunked prefill
* chore: remove stray spec_decode copy
Motivation
No.12: add unit tests for fastdeploy/spec_decode/mtp.py.
On the develop branch, coverage was 56%, with 199 missed lines (40-50, 101, 104, 108, 135->137, 158, 196-197, 204, 207-233, 247-252, 286, 339->exit, 475, 498-505, 514, 552-558, 564-586, 597-674, 716-742, 748-751, 821-833, 870-882, 922, 946-948, 973-974, 989->996, 998, 1000->849, 1010-1091, 1094-1100, 1106-1129, 1143, 1161-1181, 1191, 1195, 1213-1214, 1219-1226).
With this PR, coverage is 81%, with 79 missed lines.
Lines newly covered by the added tests: 199 - 79 = 120.
Modifications
Added the unit test file tests/spec_decode/test_mtp_proposer.py.
Usage or Command
Not needed.
Accuracy Tests
Not needed.
Checklist
- Tag the PR with one of: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a PR against the release branch, make sure the PR has been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.