
Prefix caching | Mamba memory only. #3657

Merged
lmcafee-nvidia merged 2 commits into NVIDIA:main from
lmcafee-nvidia:prefix-caching-mamba-memory-only
Mar 3, 2026

Conversation

@lmcafee-nvidia
Contributor

Summary

  • Hybrid models (Transformer + Mamba) cannot skip prefill computation because Mamba layers maintain recurrent states that depend on the full sequence history, unlike attention KV cache blocks which are self-contained and reusable.
  • This PR adds a guard in _compute_prefix_match that forces prefix_skip_tokens = 0 when is_hybrid_model is True, so matched prefix blocks are shared (saving memory) while all prompt tokens are still processed through the model (preserving Mamba state correctness).
  • Adds 4 tests in TestHybridModelMemoryOnly verifying: no prefill skipping, block reuse for memory savings, correct ref counts for shared blocks, and all prompt tokens present in context.
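The contrast driving the first bullet can be sketched with a toy model (illustrative only; not the real cache code): attention KV entries are per-token, so a cached block of them is self-contained, while a Mamba-style recurrent state folds the entire history into one value and cannot be recovered from a cached block in isolation.

```python
def kv_entries(tokens):
    # One entry per token, independent of all other tokens,
    # so a block of entries is reusable as-is.
    return [ord(t) for t in tokens]

def recurrent_state(tokens):
    # Each step folds the previous state in, so the final state
    # depends on every token seen so far.
    state = 0
    for t in tokens:
        state = (31 * state + ord(t)) % (2**32)
    return state

# Two requests sharing the prefix "abc": their KV entries for that
# prefix are identical and reusable block-for-block...
assert kv_entries("abcX")[:3] == kv_entries("abcY")[:3]
# ...but the recurrent state after the full prompt differs per request,
# so it must be recomputed from all tokens rather than read from a block.
assert recurrent_state("abcX") != recurrent_state("abcY")
```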

Details

When prefix caching is enabled for a hybrid model, the system operates in "memory-only" mode:

  • KV cache blocks are shared across requests with matching prefix hashes, reducing memory consumption.
  • Prefill is NOT skipped because Mamba layers must process the full sequence to reconstruct their internal states.

The change is a single 3-line guard in _compute_prefix_match (~line 1624 of dynamic_context.py):

# Hybrid models: disable prefill skipping (no Mamba states per block),
# but keep matched blocks for memory sharing.
if self.is_hybrid_model:
    prefix_skip_tokens = 0

Benchmarked on a 2B hybrid model (23 Mamba + 4 Attention + 23 MLP layers, 50 total) with 10 identical requests (644 tokens each):

  • 64.2% block savings (31.0 → 11.1 blocks used)
  • 0% prefill token reduction (6440 tokens in all configs), as expected
  • Token-for-token output correctness vs. prefix caching disabled
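The reported figures are internally consistent and can be reproduced from the numbers above:

```python
# Block savings: 31.0 blocks without caching vs 11.1 with caching.
blocks_without_caching = 31.0
blocks_with_caching = 11.1
savings = 1 - blocks_with_caching / blocks_without_caching
print(f"{savings:.1%}")  # → 64.2%

# Prefill is unchanged: 10 requests x 644 tokens, all processed.
assert 10 * 644 == 6440
```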

Test plan

  • test_no_prefill_skipping_for_hybrid_model: verifies prefix_skip_tokens == 0 and effective_chunk_length == chunk_length even when blocks match
  • test_matched_blocks_reused_saving_memory: verifies second request consumes no additional blocks from pool
  • test_ref_counts_incremented_for_matched_blocks: verifies matched blocks have ref_count == 2 after sharing
  • test_all_prompt_tokens_in_context: verifies all prompt tokens are active (none skipped) and kv_length_offset == 0
/opt/venv/bin/python -m torch.distributed.run --nproc-per-node 1 -m pytest \
  tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py -v
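A minimal self-contained analogue of the first test (names and structure are illustrative; the real tests exercise the actual inference context in dynamic_context.py):

```python
# Simplified stand-in for the context under test (hypothetical API).
class FakeContext:
    def __init__(self, is_hybrid_model):
        self.is_hybrid_model = is_hybrid_model

    def compute_prefix_skip(self, matched_prefix_tokens):
        prefix_skip_tokens = matched_prefix_tokens
        # The guard under test: hybrid models never skip prefill.
        if self.is_hybrid_model:
            prefix_skip_tokens = 0
        return prefix_skip_tokens

def test_no_prefill_skipping_for_hybrid_model():
    # Even with a fully matched 128-token prefix, nothing is skipped.
    assert FakeContext(is_hybrid_model=True).compute_prefix_skip(128) == 0

def test_skipping_still_allowed_for_attention_only_model():
    assert FakeContext(is_hybrid_model=False).compute_prefix_skip(128) == 128

test_no_prefill_skipping_for_hybrid_model()
test_skipping_still_allowed_for_attention_only_model()
```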

🤖 Generated with Claude Code

@lmcafee-nvidia lmcafee-nvidia requested review from a team as code owners March 2, 2026 20:14
@lmcafee-nvidia lmcafee-nvidia self-assigned this Mar 2, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 2, 2026 20:14
@lmcafee-nvidia lmcafee-nvidia changed the title from "Disable prefill skipping for hybrid models while preserving block sharing" to "Prefix caching | Mamba memory only." Mar 2, 2026
@copy-pr-bot

copy-pr-bot bot commented Mar 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@lmcafee-nvidia lmcafee-nvidia requested a review from a team as a code owner March 2, 2026 20:38
@lmcafee-nvidia lmcafee-nvidia marked this pull request as draft March 3, 2026 02:48
@lmcafee-nvidia lmcafee-nvidia marked this pull request as ready for review March 3, 2026 02:49
Disable prefill skipping for hybrid models while preserving block sharing

Hybrid models (Transformer + Mamba) lack per-block Mamba states, so prefix
computation cannot be skipped. This adds a guard in _compute_prefix_match
that forces prefix_skip_tokens=0 when is_hybrid_model is True, ensuring all
tokens are recomputed while still sharing KV blocks for memory savings.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@lmcafee-nvidia lmcafee-nvidia force-pushed the prefix-caching-mamba-memory-only branch from e6fd0b4 to cc45bb3 March 3, 2026 04:00
@lmcafee-nvidia
Contributor Author

/ok to test 13af54a


@lmcafee-nvidia lmcafee-nvidia added this pull request to the merge queue Mar 3, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22643519147

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22645678779

Merged via the queue into NVIDIA:main with commit 6fc7690 Mar 3, 2026
73 of 82 checks passed
@lmcafee-nvidia lmcafee-nvidia deleted the prefix-caching-mamba-memory-only branch March 3, 2026 23:23
