Remove state init #4604
grimoire commented May 20, 2026
- Fill the conv state in the model during forward
- Update the gate to ignore the init state of GDR
- Remove init cache; it takes too much time
- Apply L2 norm before repeat_interleave
Pull request overview
This PR removes explicit state-cache initialization for the GatedDelta/SSM path by making the model handle “init state” behavior during forward, and optimizes the kv head replication path by applying Q/K L2-normalization before repeat_interleave to reduce overhead.
Changes:
- Add init-state metadata (is_init, is_init_token) to GatedDeltaMeta, zero conv initial states on init, and mask GDR gate for init tokens.
- Move kv_ratio replication logic into GatedDelta (and add a helper that normalizes before replication).
- Remove StateCacheEngine.init_caches and its call site during model forward.
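The "normalize before replication" change can be illustrated with a minimal sketch. It uses plain Python lists in place of tensors, and the helper names `l2norm` and `repeat_interleave` (along with the `eps` value) are stand-ins for illustration, not the functions in this PR. Because L2 normalization acts on each head independently, normalizing the original kv heads and then repeating them gives the same result as repeating first, while normalizing `kv_ratio` times fewer vectors.

```python
import math


def l2norm(vec, eps=1e-6):
    """Scale a vector to (approximately) unit L2 norm; eps avoids division by zero."""
    scale = 1.0 / math.sqrt(sum(x * x for x in vec) + eps)
    return [x * scale for x in vec]


def repeat_interleave(heads, repeats):
    """Repeat each head `repeats` times, keeping copies adjacent (h0, h0, h1, h1, ...)."""
    return [head for head in heads for _ in range(repeats)]


# Two kv heads with head_dim = 2; kv_ratio = 2 expands them to four heads.
k_heads = [[3.0, 4.0], [1.0, 0.0]]
kv_ratio = 2

# Order used after this PR (as described above): normalize 2 heads, then repeat.
norm_then_repeat = repeat_interleave([l2norm(h) for h in k_heads], kv_ratio)

# Previous order: repeat to 4 heads, then normalize all 4.
repeat_then_norm = [l2norm(h) for h in repeat_interleave(k_heads, kv_ratio)]

print(norm_then_repeat == repeat_then_norm)
```

Both orders produce identical heads; only the number of normalization calls differs.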
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| lmdeploy/pytorch/nn/gated_delta.py | Adds init-token handling and moves kv replication + (optional) Q/K L2-norm before replication into the GatedDelta wrapper. |
| lmdeploy/pytorch/models/qwen3_5.py | Stops repeating Q/K in the model and passes kv_ratio into GatedDelta. |
| lmdeploy/pytorch/engine/model_agent/agent.py | Removes the state cache initialization call during forward. |
| lmdeploy/pytorch/engine/cache_engine.py | Removes StateCacheEngine.init_caches implementation. |
    self.is_init = None
    self.is_init_token = None
    if not self.is_decoding:
Will this work for dp>1?
With the condition self.is_init = (attn_metadata.kv_seqlens - attn_metadata.q_seqlens) == 0, I think it should be OK?
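For reference, a small sketch of what that condition evaluates to, using plain Python lists instead of tensors. The sequence lengths are made-up values, and the per-token expansion into `is_init_token` is an assumption about how the per-sequence flag would be broadcast, not code taken from this PR. A sequence counts as "init" when its kv length equals its query length, meaning it has no cached history yet.

```python
# Hypothetical batch of three sequences in a prefill step.
kv_seqlens = [5, 12, 3]  # total tokens visible to each sequence (history + new)
q_seqlens = [5, 4, 3]    # new tokens being processed for each sequence

# Per-sequence flag: no history means kv_seqlens - q_seqlens == 0.
is_init = [kv - q == 0 for kv, q in zip(kv_seqlens, q_seqlens)]

# Assumed per-token expansion: every token inherits its sequence's flag.
is_init_token = [flag for flag, q in zip(is_init, q_seqlens) for _ in range(q)]

print(is_init)             # [True, False, True]
print(len(is_init_token))  # 12
```

The flag depends only on each sequence's own lengths, so it is computed locally from whatever batch a rank receives.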
    beta = b.sigmoid()
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    g = self.get_A_log_exp() * F.softplus(a.float() + self.dt_bias)
    if self.kv_ratio > 1:
Should we update qwen3_next similarly?
https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/pytorch/models/qwen3_next.py#L190