[CI][Hackathon 9th Sprint No.12] Add unit tests for fastdeploy/spec_decode/mtp.py #5533

Merged
CSWYF3634076 merged 10 commits into PaddlePaddle:develop from kesmeey:12
Dec 17, 2025
Conversation

@kesmeey
Collaborator

@kesmeey kesmeey commented Dec 12, 2025

Motivation

No. 12: add unit tests for fastdeploy/spec_decode/mtp.py

(screenshot: coverage report on the develop branch)

On the develop branch, coverage is 56% with 199 missed lines (40-50, 101, 104, 108, 135->137, 158, 196-197, 204, 207-233, 247-252, 286, 339->exit, 475, 498-505, 514, 552-558, 564-586, 597-674, 716-742, 748-751, 821-833, 870-882, 922, 946-948, 973-974, 989->996, 998, 1000->849, 1010-1091, 1094-1100, 1106-1129, 1143, 1161-1181, 1191, 1195, 1213-1214, 1219-1226).

(screenshot: coverage report with this PR)

With this PR, coverage is 81% with 79 missed lines.

Lines newly covered by the added unit tests: 199 - 79 = 120.

Modifications

add unittest tests/spec_decode/test_mtp_proposer.py

Usage or Command

no need

Accuracy Tests

no need

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings December 12, 2025 08:15
@paddle-bot

paddle-bot bot commented Dec 12, 2025

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Dec 12, 2025
Contributor

Copilot AI left a comment


Pull request overview

This PR adds comprehensive unit tests for the MTP (Multi-Token-Prediction) proposer module (fastdeploy/spec_decode/mtp.py) as part of Hackathon 9th Sprint. The tests use extensive mocking to verify various initialization, configuration, and runtime behaviors of the MTPProposer class.

  • Adds 498 lines of unit tests covering MTPProposer initialization, configuration updates, KV cache management, and various runtime operations
  • Tests multiple scenarios including different request types (PREFILL, DECODE, PREEMPTED), chunked prefill, expert parallel, and multimodal inputs
  • Uses unittest.mock extensively to isolate the component under test

@@ -0,0 +1,526 @@
"""

Copilot AI Dec 12, 2025


The PR title format is correct with the [CI] tag, but it should be translated to English to match standard GitHub PR practices. The Chinese title should be translated to something like: "[CI][Hackathon 9th Sprint No.12] Add unit tests for fastdeploy/spec_decode/mtp.py"

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +197 to +271
        self.assertEqual(proposer.main_model_num_gpu_blocks, 20)
        self.assertIn("free_list", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test insert_tasks_v1 with different request types"""
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

        proposer = MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )

        # Test with PREFILL request
        request1 = Request(
            request_id="test1",
            prompt="test",
            prompt_token_ids=[1, 2, 3, 4, 5],
            prompt_token_ids_len=5,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request1.idx = 0
        request1.task_type = RequestType.PREFILL
        request1.prefill_start_index = 0
        request1.prefill_end_index = 5
        request1.output_token_ids = []
        request1.block_tables = [0, 1]

        # Test with DECODE request
        request2 = Request(
            request_id="test2",
            prompt="test",
            prompt_token_ids=[1, 2],
            prompt_token_ids_len=2,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request2.idx = 1
        request2.task_type = RequestType.DECODE
        request2.block_tables = [2, 3]

        # Test with PREEMPTED request
        request3 = Request(
            request_id="test3",
            prompt="test",
            prompt_token_ids=[1],
            prompt_token_ids_len=1,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request3.idx = 0
        request3.task_type = RequestType.PREEMPTED

        # Test splitwise_role == "decode"
        self.fd_config.scheduler_config.splitwise_role = "decode"

Copilot AI Dec 12, 2025


The test creates a Request object with numerous parameters but doesn't validate that the MTPProposer correctly processes all these fields. Consider adding assertions to verify that the proposer correctly handles request.idx, request.block_tables, request.draft_token_ids, etc., after insertion.

Comment on lines +464 to +467
    def test_extend_draft_token_and_run_impl(self, mock_ngram, mock_rope, mock_attn_backend, mock_model_loader):
        """Test _extend_draft_token_with_ngram_match and _run_impl"""
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))

Copilot AI Dec 12, 2025


The test uses patch.object to mock internal methods but doesn't verify these mocked methods were called with the correct arguments. Consider adding assertions like mock_method.assert_called_once_with(expected_args) to ensure the implementation flow is correct.
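A minimal sketch of that pattern with `unittest.mock`; the method name `prepare` and its arguments here are hypothetical stand-ins, not taken from MTPProposer:

```python
from unittest.mock import MagicMock

# Hypothetical stand-in for a mocked internal method such as _prepare_inputs.
prepare = MagicMock(name="prepare")
prepare(hidden_states="hs", step=1)

# Verifies both that the call happened exactly once and what it received;
# without this, a wrong argument would slip through silently.
prepare.assert_called_once_with(hidden_states="hs", step=1)
```

`assert_called_once_with` fails the test if the mock was never called, called more than once, or called with different arguments.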

Comment on lines +491 to +494
            patch.object(proposer, "_update_status"),
        ):
            proposer._run_impl(full_hidden_states)


Copilot AI Dec 12, 2025


The test for _empty_cache only verifies that paddle.device.cuda.empty_cache is called but doesn't test the actual cache clearing behavior or side effects. Consider verifying that the method actually affects the cache state or memory usage in a meaningful way.

Comment on lines +114 to +443
@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_init_and_config_methods(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test initialization and config update methods"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn_backend.return_value.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)

# Test _update_mtp_config
self.assertEqual(proposer.model_config.architectures[0], "ErnieMTPForCausalLM")
self.assertEqual(proposer.model_config.num_hidden_layers, 1)
self.assertEqual(proposer.speculative_config.model_type, "mtp")

# Test _get_cache_type
cache_type = proposer._get_cache_type()
self.assertIn(cache_type, ["uint8", "int8"])

# Test is_chunk_prefill_enabled
self.assertTrue(proposer.is_chunk_prefill_enabled())

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_dummy_prefill_inputs_and_kv_cache(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test dummy_prefill_inputs and initialize_kv_cache with different branches"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)

# Test dummy_prefill_inputs with expert parallel
self.fd_config.parallel_config.enable_expert_parallel = True
proposer.dummy_prefill_inputs(num_tokens=100, batch_size=2, expected_decode_len=10)
self.assertGreater(proposer.model_inputs["seq_lens_encoder"][0].item(), 0)

# Test initialize_kv_cache with prefix caching
self.fd_config.cache_config.enable_prefix_caching = True
proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)
self.assertIn("caches", proposer.model_inputs)

# Test initialize_kv_cache with block_wise_fp8
self.fd_config.quant_config = QuantizationConfig({})
self.fd_config.quant_config.kv_cache_quant_type = "block_wise_fp8"
proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)

# Test initialize_kv_cache with profile=True
proposer.initialize_kv_cache(main_model_num_blocks=10, profile=True)

# Test clear_mtp_cache
proposer.clear_mtp_cache()
self.assertNotIn("caches", proposer.model_inputs)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_update_mtp_block_num(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test update_mtp_block_num"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)
proposer.update_mtp_block_num(num_gpu_blocks=20)
self.assertEqual(proposer.main_model_num_gpu_blocks, 20)
self.assertIn("free_list", proposer.model_inputs)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test insert_tasks_v1 with different request types"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)

# Test with PREFILL request
request1 = Request(
request_id="test1",
prompt="test",
prompt_token_ids=[1, 2, 3, 4, 5],
prompt_token_ids_len=5,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=[2],
arrival_time=0.0,
)
request1.idx = 0
request1.task_type = RequestType.PREFILL
request1.prefill_start_index = 0
request1.prefill_end_index = 5
request1.output_token_ids = []
request1.block_tables = [0, 1]

# Test with DECODE request
request2 = Request(
request_id="test2",
prompt="test",
prompt_token_ids=[1, 2],
prompt_token_ids_len=2,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=[2],
arrival_time=0.0,
)
request2.idx = 1
request2.task_type = RequestType.DECODE
request2.block_tables = [2, 3]

# Test with PREEMPTED request
request3 = Request(
request_id="test3",
prompt="test",
prompt_token_ids=[1],
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=[2],
arrival_time=0.0,
)
request3.idx = 0
request3.task_type = RequestType.PREEMPTED

# Test splitwise_role == "decode"
self.fd_config.scheduler_config.splitwise_role = "decode"
proposer.insert_tasks_v1([request1], 1)

# Test with multimodal
proposer.enable_mm = True
request1.multimodal_inputs = {"attention_mask_offset": [0, 1, 2, 3, 4]}
proposer.model_inputs["attn_mask_offsets_full"] = paddle.zeros([2, 2048], dtype="int32")
proposer.model_inputs["attn_mask_offsets_decoder"] = paddle.zeros([2, 1], dtype="int32")
proposer.insert_tasks_v1([request1], 1)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_insert_prefill_inputs(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test insert_prefill_inputs with different roles and chunked prefill"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)

request = Request(
request_id="test",
prompt="test",
prompt_token_ids=[1, 2, 3, 4, 5],
prompt_token_ids_len=5,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=[2],
arrival_time=0.0,
)
request.idx = 0
request.block_tables = [0, 1]
request.draft_token_ids = [10, 11]

# Test with prefill role
request.disaggregate_info = {"role": "prefill"}
proposer.insert_prefill_inputs([request], 1)
self.assertEqual(proposer.role, "prefill")

# Test with decode role
request.disaggregate_info = {"role": "decode"}
proposer.insert_prefill_inputs([request], 1)
self.assertEqual(proposer.role, "decode")

# Test with chunked prefill
self.fd_config.cache_config.enable_chunked_prefill = True
request.prefill_chunk_info = [3, 2]
request.disaggregate_info = None
proposer.insert_prefill_inputs([request], 1)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_forward_meta_and_exist_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test _initialize_forward_meta, _initialize_forward_meta_xpu, and exist_prefill"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)
proposer.initialize_kv_cache(main_model_num_blocks=10)

# Test _initialize_forward_meta
proposer._initialize_forward_meta(step_use_cudagraph=False)
self.assertIsNotNone(proposer.forward_meta)

# Test _initialize_forward_meta_xpu
proposer._initialize_forward_meta_xpu()
self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")

# Test exist_prefill
proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
result = proposer.exist_prefill()
self.assertEqual(result, 1)

proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
result = proposer.exist_prefill()
self.assertEqual(result, 0)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
@patch("fastdeploy.spec_decode.mtp.draft_model_preprocess")
@patch("fastdeploy.spec_decode.mtp.eagle_get_hidden_states")
def test_prepare_inputs_and_post_process(
self, mock_eagle, mock_preprocess, mock_rope, mock_attn_backend, mock_model_loader
):
"""Test _prepare_inputs and _post_process"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])
mock_eagle.return_value = paddle.zeros([2, 768], dtype="bfloat16")
mock_preprocess.return_value = None

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)
full_hidden_states = paddle.zeros([2, 768], dtype="bfloat16")

# Test _prepare_inputs
proposer._prepare_inputs(full_hidden_states)
mock_preprocess.assert_called()
mock_eagle.assert_called()

# Test _post_process with prefill role
proposer.role = "prefill"
sampled_token_ids = paddle.ones([2, 1], dtype="int64")
proposer._post_process(sampled_token_ids)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
def test_update_task_chunk_prefill(self, mock_rope, mock_attn_backend, mock_model_loader):
"""Test update_task_chunk_prefill"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model
mock_attn = Mock()
mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
mock_attn_backend.return_value = mock_attn
mock_rope.return_value = paddle.zeros([1, 2048, 64])

proposer = MTPProposer(
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
)

task = Mock()
task.idx = 0
task.prefill_chunk_info = [3, 2, 1]
task.prompt_token_ids = [1, 2, 3, 4, 5, 6]

# Test chunk_idx == len(prefill_chunk_info)
task.chunk_idx = 3
task.get = Mock(return_value=0)
proposer.update_task_chunk_prefill(task)

# Test chunk_idx < len - 1
task.chunk_idx = 0
proposer.update_task_chunk_prefill(task)

# Test last prefill
task.chunk_idx = 2
proposer.update_task_chunk_prefill(task)

@patch("fastdeploy.spec_decode.mtp.get_model_loader")
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@patch("fastdeploy.spec_decode.mtp.get_rope")
@patch("fastdeploy.spec_decode.mtp.draft_model_postprocess")
@patch("fastdeploy.spec_decode.mtp.mtp_step_paddle")
def test_update_status(self, mock_mtp_step, mock_postprocess, mock_rope, mock_attn_backend, mock_model_loader):
"""Test _update_status"""
mock_model = Mock()
mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
mock_model_loader.return_value.load_model.return_value = mock_model

Copilot AI Dec 12, 2025


The test repeatedly patches the same three decorators (get_model_loader, get_attention_backend, get_rope) across multiple test methods. This is a code smell indicating duplication. Consider using a test class-level patcher or a setUp helper method that applies these patches once for all tests, improving maintainability and reducing boilerplate.
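One way to implement that suggestion is a shared base class that starts the patchers in `setUp` and registers cleanup; the sketch below patches a stand-in `_FakeMtpModule` so it runs without fastdeploy installed, but the same shape applies to `fastdeploy.spec_decode.mtp`:

```python
import unittest
from unittest.mock import MagicMock, patch

class _FakeMtpModule:
    """Stand-in for fastdeploy.spec_decode.mtp so the sketch runs anywhere."""
    get_model_loader = staticmethod(lambda: "real loader")

class MTPProposerTestBase(unittest.TestCase):
    """Shared patches applied once in setUp instead of per-test decorators."""

    def setUp(self):
        patcher = patch.object(_FakeMtpModule, "get_model_loader", MagicMock(name="loader"))
        self.mock_model_loader = patcher.start()
        self.addCleanup(patcher.stop)  # reverted automatically after each test

class ExampleTest(MTPProposerTestBase):
    def test_loader_is_mocked(self):
        # The patched attribute is exactly the mock returned by start().
        self.assertIs(_FakeMtpModule.get_model_loader, self.mock_model_loader)

suite = unittest.TestLoader().loadTestsFromTestCase(ExampleTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Each concrete test class then inherits the mocks via `self.mock_model_loader` and friends, dropping three decorators and three parameters per test method.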

Comment on lines +63 to +112
        self.target_model_inputs = {
            "block_tables": paddle.zeros([2, 10], dtype="int32"),
            "input_ids": paddle.zeros([2, 2048], dtype="int64"),
            "seq_lens_this_time": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
            "prompt_lens": paddle.zeros([2, 1], dtype="int64"),
            "step_idx": paddle.zeros([2, 1], dtype="int64"),
            "stop_flags": paddle.zeros([2, 1], dtype="bool"),
            "stop_nums": paddle.zeros([2, 1], dtype="int32"),
            "pre_ids": paddle.zeros([2, 2048], dtype="int64"),
            "output_cum_offsets": paddle.zeros([2], dtype="int32"),
            "output_padding_offset": paddle.zeros([2], dtype="int32"),
            "ids_remove_padding": paddle.zeros([2], dtype="int64"),
            "batch_id_per_token": paddle.zeros([2], dtype="int32"),
            "cu_seqlens_q": paddle.zeros([3], dtype="int32"),
            "cu_seqlens_k": paddle.zeros([3], dtype="int32"),
            "decoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "decoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "decoder_num_blocks_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "decoder_num_blocks_device": paddle.zeros([2], dtype="int32"),
            "decoder_chunk_size_device": paddle.zeros([2], dtype="int32"),
            "max_len_tensor_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "encoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "encoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "encoder_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "kv_batch_ids": paddle.zeros([2], dtype="int32"),
            "kv_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "kv_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "top_p": paddle.ones([2, 1], dtype="float32") * 0.9,
            "top_k": paddle.zeros([2, 1], dtype="int32"),
            "temperature": paddle.ones([2, 1], dtype="float32"),
            "eos_token_id": paddle.ones([2], dtype="int64") * 2,
            "penalty_score": paddle.ones([2, 1], dtype="float32"),
            "frequency_score": paddle.zeros([2, 1], dtype="float32"),
            "presence_score": paddle.zeros([2, 1], dtype="float32"),
            "infer_seed": paddle.zeros([2, 1], dtype="int64"),
            "max_dec_len": paddle.ones([2, 1], dtype="int64") * 512,
            "min_dec_len": paddle.zeros([2, 1], dtype="int64"),
            "bad_tokens": paddle.zeros([2], dtype="int64"),
            "draft_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_num": paddle.ones([2], dtype="int32"),
            "draft_logits": paddle.zeros([4, 32000], dtype="float32"),
            "temp_scaled_logprobs": paddle.zeros([2], dtype="float32"),
            "top_p_normalized_logprobs": paddle.zeros([2], dtype="float32"),
            "encoder_block_lens": paddle.zeros([2, 1], dtype="int32"),
            "cu_batch_token_offset": paddle.zeros([3], dtype="int32"),
            "is_block_step": paddle.zeros([2], dtype="bool"),
        }

Copilot AI Dec 12, 2025


The target_model_inputs dictionary is duplicated in the setUp method with extensive boilerplate. This large initialization block (lines 63-112) makes the setUp method hard to read and maintain. Consider extracting this into a helper method like _create_default_target_model_inputs() or moving it to the test utilities module for reuse across tests.
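A possible shape for such a helper, sketched with an injected `zeros` factory so the snippet runs without paddle; the real suite would pass `paddle.zeros`, and the spec table would list every key from `setUp`:

```python
def make_target_model_inputs(zeros):
    """Build the shared input dict from a compact (shape, dtype) spec table."""
    spec = {
        "block_tables": ([2, 10], "int32"),
        "input_ids": ([2, 2048], "int64"),
        "seq_lens_this_time": ([2, 1], "int32"),
        "seq_lens_encoder": ([2, 1], "int32"),
        # ... remaining keys follow the same (shape, dtype) pattern as in setUp
    }
    return {name: zeros(shape, dtype) for name, (shape, dtype) in spec.items()}

# Pure-Python stand-in for paddle.zeros, used only so this sketch is runnable.
def _py_zeros(shape, dtype):
    return {"shape": shape, "dtype": dtype}

inputs = make_target_model_inputs(_py_zeros)
```

The spec table keeps each key to one line, and tests that need a variant (e.g. a different batch size) can copy and patch the dict rather than re-declaring fifty tensors.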

Comment on lines +120 to +123
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn_backend.return_value.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

Copilot AI Dec 12, 2025


The mock return values use hardcoded shapes like [2, 12, 16, 64] and [2, 32000] without explaining their meaning. These should either use named constants or include comments explaining what each dimension represents (e.g., batch_size=2, num_layers=12, seq_len=16, head_dim=64).
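For instance (the dimension names below are assumptions about what the hardcoded shapes mean, not confirmed by the code under test):

```python
# Assumed meanings of the hardcoded mock shapes; adjust if the real
# attention backend uses a different layout.
BATCH_SIZE = 2
NUM_KV_HEADS = 12
BLOCK_SIZE = 16
HEAD_DIM = 64
VOCAB_SIZE = 32000

KV_CACHE_SHAPE = [BATCH_SIZE, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM]
LOGITS_SHAPE = [BATCH_SIZE, VOCAB_SIZE]

# The mocks could then read, e.g.:
# mock_attn.get_kv_cache_shape.return_value = (KV_CACHE_SHAPE, KV_CACHE_SHAPE)
# mock_model.compute_logits = Mock(return_value=paddle.zeros(LOGITS_SHAPE))
```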

Comment on lines +210 to +271
        mock_attn_backend.return_value = mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

        proposer = MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )

        # Test with PREFILL request
        request1 = Request(
            request_id="test1",
            prompt="test",
            prompt_token_ids=[1, 2, 3, 4, 5],
            prompt_token_ids_len=5,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request1.idx = 0
        request1.task_type = RequestType.PREFILL
        request1.prefill_start_index = 0
        request1.prefill_end_index = 5
        request1.output_token_ids = []
        request1.block_tables = [0, 1]

        # Test with DECODE request
        request2 = Request(
            request_id="test2",
            prompt="test",
            prompt_token_ids=[1, 2],
            prompt_token_ids_len=2,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request2.idx = 1
        request2.task_type = RequestType.DECODE
        request2.block_tables = [2, 3]

        # Test with PREEMPTED request
        request3 = Request(
            request_id="test3",
            prompt="test",
            prompt_token_ids=[1],
            prompt_token_ids_len=1,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
            arrival_time=0.0,
        )
        request3.idx = 0
        request3.task_type = RequestType.PREEMPTED

        # Test splitwise_role == "decode"
        self.fd_config.scheduler_config.splitwise_role = "decode"

Copilot AI Dec 12, 2025


The test creates Request objects with complex initialization but doesn't test edge cases like empty prompt_token_ids, None values, or invalid indices. Consider adding negative test cases to verify the proposer handles invalid inputs gracefully.
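A hypothetical negative-case sketch; `validate_prompt` stands in for whatever validation the proposer should perform, since its actual error behavior is not specified in this PR:

```python
import unittest

def validate_prompt(prompt_token_ids):
    """Hypothetical guard standing in for the proposer's input handling."""
    if not prompt_token_ids:
        raise ValueError("prompt_token_ids must be non-empty")
    return len(prompt_token_ids)

class NegativeCases(unittest.TestCase):
    def test_empty_prompt_rejected(self):
        with self.assertRaises(ValueError):
            validate_prompt([])

    def test_none_prompt_rejected(self):
        with self.assertRaises(ValueError):
            validate_prompt(None)

suite = unittest.TestLoader().loadTestsFromTestCase(NegativeCases)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

The same `assertRaises` pattern would apply to invalid `idx` values or mismatched `block_tables`, once the intended failure mode is decided.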

Comment on lines +199 to +205

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test insert_tasks_v1 with different request types"""
        mock_model = Mock()

Copilot AI Dec 12, 2025


The test uses the same mock setup repeatedly (mock_model with compute_logits, mock_attn with get_kv_cache_shape returning [2, 12, 16, 64], and mock_rope returning zeros). Consider extracting this into a helper method like _setup_common_mocks() to reduce duplication and improve test maintainability.

Comment on lines +421 to +438
        # Test chunk_idx == len(prefill_chunk_info)
        task.chunk_idx = 3
        task.get = Mock(return_value=0)
        proposer.update_task_chunk_prefill(task)

        # Test chunk_idx < len - 1
        task.chunk_idx = 0
        proposer.update_task_chunk_prefill(task)

        # Test last prefill
        task.chunk_idx = 2
        proposer.update_task_chunk_prefill(task)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    @patch("fastdeploy.spec_decode.mtp.draft_model_postprocess")
    @patch("fastdeploy.spec_decode.mtp.mtp_step_paddle")

Copilot AI Dec 12, 2025


The test for _update_status only verifies that mtp_step_paddle is called but doesn't check the function was called with correct arguments or verify the state changes. Similarly, the test doesn't cover the ENABLE_V1_KVCACHE_SCHEDULER=True branch. Consider adding assertions for both branches and verifying the actual behavior.

@kesmeey kesmeey force-pushed the 12 branch 2 times, most recently from d9ecb6f to f839869 on December 12, 2025 09:29
@codecov-commenter

codecov-commenter commented Dec 12, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@e927c65). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5533   +/-   ##
==========================================
  Coverage           ?   61.31%           
==========================================
  Files              ?      329           
  Lines              ?    41158           
  Branches           ?     6274           
==========================================
  Hits               ?    25238           
  Misses             ?    14030           
  Partials           ?     1890           
Flag Coverage Δ
GPU 61.31% <ø> (?)

Flags with carried forward coverage won't be shown.


        )

        # Test with PREFILL request
        request1 = Request(
Collaborator


There are three requests here, but request1 and request3 share the same idx, and only request1 is actually inserted; what is this meant to test?
Suggestion: insert all three requests, then assert that the values written into share_inputs during insertion are correct.
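A rough sketch of the kind of post-insert check being requested, using a toy dict and insert routine in place of the real proposer state; the field names mirror the test fixture, but the insertion logic here is invented for illustration:

```python
# Toy model_inputs dict and insert routine standing in for the proposer state.
model_inputs = {"seq_lens_this_time": [0, 0], "block_tables": [[], []]}

def toy_insert(requests):
    """Write per-slot lengths and block tables, like insert_tasks_v1 would."""
    for req in requests:
        model_inputs["seq_lens_this_time"][req["idx"]] = len(req["tokens"])
        model_inputs["block_tables"][req["idx"]] = req["blocks"]

toy_insert([
    {"idx": 0, "tokens": [1, 2, 3, 4, 5], "blocks": [0, 1]},  # PREFILL
    {"idx": 1, "tokens": [1, 2], "blocks": [2, 3]},           # DECODE
])

# The suggested checks: assert the per-slot values written during insertion.
assert model_inputs["seq_lens_this_time"] == [5, 2]
assert model_inputs["block_tables"] == [[0, 1], [2, 3]]
```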

        self.fd_config.cache_config.enable_chunked_prefill = True
        request.prefill_chunk_info = [3, 2]
        request.disaggregate_info = None
        proposer.insert_prefill_inputs([request], 1)
Collaborator


Same as above: verify that the values after insertion are correct.

        # Test _post_process with prefill role
        proposer.role = "prefill"
        sampled_token_ids = paddle.ones([2, 1], dtype="int64")
        proposer._post_process(sampled_token_ids)
Collaborator


Same as above: this only executes the call; please also check that the resulting values are correct.

        proposer.model_inputs["seq_lens_this_time"] = proposer.seq_lens_this_time_buffer

        # Test with ENABLE_V1_KVCACHE_SCHEDULER=False
        with patch("fastdeploy.spec_decode.mtp.envs.ENABLE_V1_KVCACHE_SCHEDULER", False):
Collaborator


Same as above.

Collaborator Author


Hi, the issues have been fixed, and CI has finished running.

@CSWYF3634076
Collaborator

@freeliuzc please take another look at whether the changes meet the requirements.

@CSWYF3634076 CSWYF3634076 merged commit ac73165 into PaddlePaddle:develop Dec 17, 2025
16 of 18 checks passed
@CSWYF3634076
Collaborator

0.1⭐️
@luotao1

chang-wenbin pushed a commit to chang-wenbin/FastDeploy that referenced this pull request Mar 2, 2026
…unit test additions (PaddlePaddle#5533)

* Add unit tests for MTPProposer class in spec_decode/mtp.py

* fix: remove non-existent QuantizationConfig import in test_mtp_proposer

* fix: add logprobs_mode attribute to FakeModelConfig

* fix: fix test failures in test_mtp_proposer - fix Mock setup, remove arrival_time, add missing keys

* fix: add seq_lens_this_time initialization and kv_cache init before insert_tasks_v1

* fix: check pos_emb_type attribute existence before assertion

* test: add minimal coverage for mtp cache type, mm init, preempted

* test: fix cache_type_branches unsupported platform on 12

* test: refine MTPProposer tests for cache type, requests and chunked prefill

* chore: remove stray spec_decode copy
7 participants