[Intel HPU] enable chunked prefill#5903
Conversation
Thanks for your contribution!
Pull request overview
This PR enables chunked prefill support for Intel HPU platform, allowing prefill operations to be split into smaller chunks when processing long sequences alongside decode operations.
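As a quick orientation, the sketch below illustrates the general idea of chunked prefill. It is a minimal, framework-free illustration; the names `chunk_prompt`, `prompt_token_ids`, and `chunk_size` are hypothetical and are not FastDeploy APIs.

```python
# Minimal, framework-free sketch of chunked prefill; names are illustrative
# and do not correspond to FastDeploy APIs.

def chunk_prompt(prompt_token_ids, chunk_size):
    """Split a long prompt into fixed-size prefill chunks."""
    return [
        prompt_token_ids[i : i + chunk_size]
        for i in range(0, len(prompt_token_ids), chunk_size)
    ]

prompt = list(range(10))          # a 10-token prompt
chunks = chunk_prompt(prompt, 4)  # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Each scheduler step handles one prefill chunk alongside the currently
# running decode requests, instead of one monolithic prefill pass.
for step, chunk in enumerate(chunks):
    print(f"step {step}: prefill chunk of {len(chunk)} tokens + pending decodes")
```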
Key changes:
- Enhanced HPU attention backend to support mixed encoder/decoder execution modes
- Modified forward metadata structure to separate encoder and decoder state management
- Added chunked prefill warmup and resource allocation logic
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| fastdeploy/worker/hpu_model_runner.py | Implements chunked prefill logic, adds mixed batch warmup, separates encoder/decoder metadata handling |
| fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py | Adds forward_mixed method to handle concurrent encoder/decoder batches, updates forward_extend and forward_decode to use separated metadata |
| fastdeploy/model_executor/forward_meta.py | Restructures HPUForwardMeta to maintain separate encoder/decoder state with dedicated fields for rotary embeddings, block metadata, and batch information |
| fastdeploy/engine/sched/resource_manager_v1.py | Adds HPU-specific token budget alignment logic to ensure chunk sizes are multiples of block_size |
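The resource_manager_v1.py row above mentions aligning the scheduler's token budget so each prefill chunk is a multiple of block_size. Below is a minimal sketch of such alignment under a simple round-down policy; the helper name `align_to_block_size` is hypothetical and the actual logic in the PR may differ.

```python
def align_to_block_size(token_budget: int, block_size: int) -> int:
    """Round a token budget down to the nearest multiple of block_size.

    Hypothetical helper illustrating the alignment described above for
    resource_manager_v1.py; the actual implementation may differ.
    """
    if token_budget < block_size:
        # Not enough budget for a full block; leave it unchanged.
        return token_budget
    return (token_budget // block_size) * block_size

# With block_size=128, a 4096-token budget is already aligned,
# while 4100 tokens would be trimmed back to 4096.
assert align_to_block_size(4096, 128) == 4096
assert align_to_block_size(4100, 128) == 4096
```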
| self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) | ||
| self.share_inputs["is_block_step"][idx : idx + 1] = False | ||
| # self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) | ||
| self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) |
This line was previously commented out. Please ensure all corresponding references to is_chunk_step are properly updated and tested throughout the codebase, particularly in functions that consume this metadata.
The variable name len shadows the built-in Python function. Consider renaming to seq_chunk_size or encoder_chunk_len for clarity and to avoid shadowing the built-in.
Suggested change:
- len = int((tmp_out.shape[0] - total_batch_decoder) / total_batch_encoder)
- position = 0
- for i in range(batch_ids_encoder.shape[0]):
-     encoder_id = batch_ids_encoder[i].item()
-     seq_len = seq_lens_encoder[encoder_id].item()
-     output_data[position] = tmp_out[i * len + seq_len - 1]
+ encoder_chunk_len = int((tmp_out.shape[0] - total_batch_decoder) / total_batch_encoder)
+ position = 0
+ for i in range(batch_ids_encoder.shape[0]):
+     encoder_id = batch_ids_encoder[i].item()
+     seq_len = seq_lens_encoder[encoder_id].item()
+     output_data[position] = tmp_out[i * encoder_chunk_len + seq_len - 1]
The index calculation i - total_batch_decoder appears incorrect: since i iterates from 0 to batch_ids_decoder.shape[0] - 1, the expression is negative on every iteration. The correct index should likely be total_batch_encoder * len + i.
Suggested change:
- len = int((tmp_out.shape[0] - total_batch_decoder) / total_batch_encoder)
- position = 0
- for i in range(batch_ids_encoder.shape[0]):
-     encoder_id = batch_ids_encoder[i].item()
-     seq_len = seq_lens_encoder[encoder_id].item()
-     output_data[position] = tmp_out[i * len + seq_len - 1]
-     position += 1
- for i in range(batch_ids_decoder.shape[0]):
-     output_data[position] = tmp_out[i - total_batch_decoder]
+ block_len = int((tmp_out.shape[0] - total_batch_decoder) / total_batch_encoder)
+ position = 0
+ for i in range(batch_ids_encoder.shape[0]):
+     encoder_id = batch_ids_encoder[i].item()
+     seq_len = seq_lens_encoder[encoder_id].item()
+     output_data[position] = tmp_out[i * block_len + seq_len - 1]
+     position += 1
+ decoder_start = total_batch_encoder * block_len
+ for i in range(batch_ids_decoder.shape[0]):
+     output_data[position] = tmp_out[decoder_start + i]
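To make the suggested change above concrete, here is a small, self-contained illustration (pure Python, hypothetical data, not the PR's actual tensors) of the layout the fix assumes: encoder outputs come first, padded to encoder_chunk_len rows per request, followed by one row per decoder request, so decoder rows start at total_batch_encoder * encoder_chunk_len.

```python
# Hypothetical mixed-batch layout: encoder rows first (padded per request),
# then one row per decoder request.
encoder_chunk_len = 4          # padded rows per encoder request
total_batch_encoder = 2
total_batch_decoder = 2
seq_lens_encoder = [3, 2]      # valid (unpadded) lengths per encoder request

tmp_out = [f"enc{e}_tok{t}" for e in range(total_batch_encoder) for t in range(encoder_chunk_len)]
tmp_out += [f"dec{d}" for d in range(total_batch_decoder)]

output_data = []
for i in range(total_batch_encoder):
    # Last valid token of each padded encoder chunk.
    output_data.append(tmp_out[i * encoder_chunk_len + seq_lens_encoder[i] - 1])

decoder_start = total_batch_encoder * encoder_chunk_len
for i in range(total_batch_decoder):
    output_data.append(tmp_out[decoder_start + i])

print(output_data)  # ['enc0_tok2', 'enc1_tok1', 'dec0', 'dec1']
```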
In the forward_mixed method's measurement mode branch for decoder, forward_meta.rotary_embs is used instead of forward_meta.rotary_embs_decoder. This inconsistency with the non-measurement branch (line 683) will cause incorrect rotary embeddings to be applied during measurement mode.
Suggested change:
- forward_meta.rotary_embs,
+ forward_meta.rotary_embs_decoder,
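For context on why the decoder-specific field matters, below is a purely illustrative sketch of what "separate encoder/decoder state" in HPUForwardMeta might look like. Only rotary_embs and rotary_embs_decoder come from the review comment above; every other field name is a hypothetical placeholder for the separated block and batch metadata described in the file table.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class HPUForwardMetaSketch:
    """Illustrative stand-in for the restructured HPUForwardMeta (not the real class)."""

    # Encoder-side (prefill chunk) state.
    rotary_embs: Optional[Any] = None             # rotary embeddings for prefill tokens
    block_metadata_encoder: Optional[Any] = None  # hypothetical block-table info
    batch_info_encoder: Optional[Any] = None      # hypothetical batch bookkeeping

    # Decoder-side state; both the measurement and non-measurement branches
    # of forward_mixed should read these fields for decode batches.
    rotary_embs_decoder: Optional[Any] = None
    block_metadata_decoder: Optional[Any] = None
    batch_info_decoder: Optional[Any] = None
```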
Line 376 sets forward_mode to ForwardMode.MIXED, then lines 377-378 conditionally set it again to the same value. The initial assignment on line 376 is redundant.
forward_mode = ForwardMode.MIXED
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #5903 +/- ##
==========================================
Coverage ? 66.70%
==========================================
Files ? 347
Lines ? 44426
Branches ? 6823
==========================================
Hits ? 29636
Misses ? 12609
Partials ? 2181
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
EmmonsCurse left a comment:
LGTM for skipping coverage.
Motivation
Enable chunked prefill on Intel HPU.
Depends on PaddlePaddle/PaddleCustomDevice#2324.
Modifications
- HPU attention backend
- HPU forward metadata
- HPU model runner
Usage or Command
Use these parameters to enable chunked prefill:
--enable-chunked-prefill
--max-num-batched-tokens 4096
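As a rough, back-of-the-envelope illustration of what --max-num-batched-tokens means under chunked prefill (hedged: in practice decode tokens share the same per-step budget, so real chunk sizes can be smaller):

```python
import math

# Hypothetical example: how many prefill steps a single long prompt needs
# when each scheduler step may process at most max_num_batched_tokens tokens.
max_num_batched_tokens = 4096   # value from the command line above
prompt_len = 10_000             # an example long prompt

prefill_steps = math.ceil(prompt_len / max_num_batched_tokens)
print(prefill_steps)  # 3 -> the prompt is prefilled in 3 chunks of <= 4096 tokens
```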
Accuracy Tests
ERNIE-4.5-21B-A3B-Paddle
Accuracy: 0.920
Invalid: 0.001
Latency: 370.744 s
Checklist
- Tag list: [[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]]
- Run pre-commit before commit.
- Tests: conducted by local tests.
- If the PR targets the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.