[#13580][fix] AutoDeploy: Support Gemma3n/4 E2B variants #13630
@coderabbitai summary |
✅ Actions performedSummary regeneration triggered. |
Gemma4 E2B Status
Latest E2E Run
Command: `bash -ic "f4 && python examples/auto_deploy/build_and_run_ad.py --model google/gemma-4-E2B-it --args.yaml-extra examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml"`
Result: passed with exit code 0.
Key signals:
Workspace after the run had no tracked changes. Only local untracked artifacts remained:
Example Prompt Responses
|
Gemma3n E2B Status
Latest E2E Run
Command: `bash -ic "f4 && python examples/auto_deploy/build_and_run_ad.py --model google/gemma-3n-E2B-it --args.yaml-extra examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml"`
Result: passed with exit code 0.
Key signals:
Workspace after the run had no tracked changes. Only local untracked artifacts remained:
Example Prompt Responses
|
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID:
📒 Files selected for processing (24)
📝 Walkthrough
This PR introduces Gemma3n and Gemma4 AutoDeploy support with shared-KV attention, per-layer inputs, and dynamic MLP scaling. It enhances the Triton paged attention backend to handle per-sequence cache metadata and read-only cache access, strengthens CUDA graph compilation with resource-input awareness, and provides comprehensive test coverage and accuracy benchmarks for the new models.
Changes
Gemma Model Support and Attention Optimization
Sequence Diagram
sequenceDiagram
participant Client
participant Gemma4ForCausalLM
participant Gemma4TextModel
participant Gemma4TextDecoderLayer
participant Gemma4TextAttention
participant TritonPagedAttention
Client->>Gemma4ForCausalLM: forward(input_ids, per_layer_inputs)
Gemma4ForCausalLM->>Gemma4TextModel: get_per_layer_inputs(input_ids)
Gemma4TextModel-->>Gemma4ForCausalLM: per_layer_inputs (projected)
Gemma4ForCausalLM->>Gemma4TextModel: forward(per_layer_inputs=...)
Gemma4TextModel->>Gemma4TextModel: embed tokens & compute per-layer contributions
loop for each decoder layer
Gemma4TextModel->>Gemma4TextDecoderLayer: forward(hidden_states, per_layer_input, shared_kv_states)
Gemma4TextDecoderLayer->>Gemma4TextAttention: forward(hidden_states, shared_kv_states)
alt is KV-shared layer
Gemma4TextAttention->>Gemma4TextAttention: fetch (k, v) from shared_kv_states
else compute new KV
Gemma4TextAttention->>Gemma4TextAttention: compute (k, v) with RoPE
Gemma4TextAttention->>Gemma4TextAttention: store (k, v) in shared_kv_states
end
Gemma4TextAttention->>TritonPagedAttention: triton_paged_mha_with_cache(...)
TritonPagedAttention-->>Gemma4TextAttention: attention output
Gemma4TextDecoderLayer->>Gemma4TextDecoderLayer: inject per_layer_input via gated/projection/norm
Gemma4TextDecoderLayer-->>Gemma4TextModel: updated hidden_states
end
Gemma4TextModel-->>Gemma4ForCausalLM: final hidden_states
Gemma4ForCausalLM-->>Client: logits
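The shared-KV branch in the diagram above can be illustrated with a minimal, framework-free sketch. Everything in it is hypothetical: `LayerSpec`, `kv_source`, `run_decoder_layers`, and `toy_compute_kv` are names chosen for illustration and are not the PR's classes or signatures. Plain Python lists stand in for tensors, and a trivial update stands in for the paged attention call.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

KV = Tuple[List[float], List[float]]


@dataclass
class LayerSpec:
    """Hypothetical per-layer description (not the PR's config class)."""
    index: int
    # Index of the earlier layer whose (k, v) this layer reuses, or None.
    kv_source: Optional[int] = None


def run_decoder_layers(
    hidden: List[float],
    layers: List[LayerSpec],
    compute_kv: Callable[[int, List[float]], KV],
) -> Tuple[List[float], List[Tuple[int, str, int]]]:
    """Walk the layers, reusing or computing KV as in the diagram's alt/else."""
    shared_kv_states: Dict[int, KV] = {}
    trace: List[Tuple[int, str, int]] = []
    for layer in layers:
        if layer.kv_source is not None:
            # KV-shared layer: fetch (k, v) produced by an earlier layer.
            k, v = shared_kv_states[layer.kv_source]
            trace.append((layer.index, "reused", layer.kv_source))
        else:
            # Regular layer: compute new (k, v) and publish it for later layers.
            k, v = compute_kv(layer.index, hidden)
            shared_kv_states[layer.index] = (k, v)
            trace.append((layer.index, "computed", layer.index))
        # Toy stand-in for attention: add the mean of v to every element.
        mean_v = sum(v) / len(v)
        hidden = [h + mean_v for h in hidden]
    return hidden, trace


def toy_compute_kv(layer_index: int, hidden: List[float]) -> KV:
    """Toy projection; a real model would apply linear layers and RoPE."""
    k = [h * (layer_index + 1) for h in hidden]
    v = [h + layer_index for h in hidden]
    return k, v
```

With layers `[LayerSpec(0), LayerSpec(1), LayerSpec(2, kv_source=1)]`, only layers 0 and 1 call `compute_kv`; layer 2 reads the entry that layer 1 stored. The point of the sketch is the ordering constraint: a KV-shared layer must run after its source layer has populated `shared_kv_states`.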
Estimated Code Review Effort
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml (1)
7-30: 🏗️ Heavy lift
Add a registry-level smoke test for the Gemma E2B configs.
This PR changes deployment entry points, but the added coverage only exercises lower-level model/attention pieces. A small AutoDeploy smoke that resolves both `examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml` and `examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml` through the model registry would catch config/factory/tokenizer wiring regressions before release. If that test lives under `tests/integration/defs/`, please also register it in the appropriate QA functional list.
As per coding guidelines, “Coverage expectations: Assess whether new/changed tests cover happy path...” and “If the Gemma3n/Gemma4 AutoDeploy fixes require end-to-end functional coverage ... add the corresponding new/updated test cases here ... so they execute in the scheduled GPU functional QA run.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml` around lines 7 - 30, Add a registry-level smoke test that resolves the Gemma E2B configs via the model registry: create a new integration test under tests/integration/defs (e.g., test_gemma_e2b_registry_smoke) that loads the model registry and attempts to resolve both examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml and examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml, asserting successful factory lookup and tokenizer resolution for model_factory Gemma4ForConditionalGeneration and the tokenizer google/gemma-4-E2B-it; register this test in the QA functional list so it runs in scheduled GPU functional QA.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)
955-1029: 🏗️ Heavy lift
Add perf coverage for the SDPA/shared-KV dispatch rewrite.
This path changes gather shape, masking, and the read-only shared-KV execution flow in a latency-sensitive attention kernel, but the PR only adds functional unit coverage. Please add or update a perf sanity case and wire it into `tests/integration/test_lists/test-db/l0_perf.yml`; add a QA `llm_perf_*` entry as well if this should run in scheduled coverage.
As per coding guidelines, “If the PR touches performance-sensitive paths ... check whether a perf test entry is present or updated in: (a) tests/integration/test_lists/test-db/l0_perf.yml ... and (b) tests/integration/test_lists/qa/llm_perf_*.yml ...”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 955 - 1029, The SDPA/shared-KV dispatch added in tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (look for use_sdpa, _fast_gather_sdpa_kernel, k_sdpa/v_sdpa and the scaled_dot_product_attention SDPA path) needs perf test coverage: add or update a perf sanity case exercising the new SDPA/shared-KV path and its altered gather/mask behavior, then wire that test into tests/integration/test_lists/test-db/l0_perf.yml and, if this should run in scheduled QA, add a corresponding entry under tests/integration/test_lists/qa/llm_perf_*.yml (use a descriptive name like llm_perf_sdpa_shared_kv) so the latency-sensitive path is included in perf runs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py`:
- Around line 886-923: The code only sets tokenizer.chat_template from a
chat_template.jinja file via cached_file(_CHAT_TEMPLATE_FILE); add a fallback to
read the chat template from the already-loaded tokenizer config (config) when
template_path is None: after the existing template_path check, if template_path
is None and config.get("chat_template") is truthy, set tokenizer.chat_template =
config["chat_template"]; ensure this logic lives alongside the existing
cached_file call (referencing cached_file, _CHAT_TEMPLATE_FILE, config, cls, and
tokenizer.chat_template) so file-based template still takes precedence over
tokenizer_config.json.
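The precedence this comment asks for can be sketched as a small standalone helper. This is an illustrative assumption, not the code in `modeling_gemma3n.py`: `resolve_chat_template` and its parameter names are hypothetical, `template_path` stands for whatever the existing `cached_file(_CHAT_TEMPLATE_FILE)` lookup returned (or `None`), and `tokenizer_config` stands for the already-loaded tokenizer config dictionary.

```python
from pathlib import Path
from typing import Any, Mapping, Optional


def resolve_chat_template(
    template_path: Optional[str],
    tokenizer_config: Mapping[str, Any],
) -> Optional[str]:
    """Pick a chat template: the file wins, the config entry is the fallback.

    Hypothetical helper for illustration only.
    """
    if template_path is not None:
        # A chat_template.jinja file was found, so it takes precedence.
        return Path(template_path).read_text(encoding="utf-8")
    # No template file: fall back to the "chat_template" entry, if truthy.
    fallback = tokenizer_config.get("chat_template")
    return fallback if fallback else None
```

A caller would assign the result to `tokenizer.chat_template` only when it is not `None`, so a tokenizer that already carries a template is left unchanged if neither source provides one.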
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: f6647ce2-c5a7-4e0c-9e25-7412170a6dcd
📒 Files selected for processing (8)
examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
examples/auto_deploy/model_registry/configs/gemma4_e2b.yaml
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
7476a5c to ad48358
|
/bot run --stage-list "A10-Build_Docs, A10-PackageSanityCheck-PY310-UB2204, A100X-PackageSanityCheck-PY312-UB2404, A30-AutoDeploy-1, H100_PCIe-AutoDeploy-1, DGX_B200-AutoDeploy-1, A100X-PyTorch-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1, DGX_B200-8_GPUs-AutoDeploy-Post-Merge-1" --disable-fail-fast |
PR_Github #46908 [ run ] triggered by Bot. Commit: |
|
/bot run --stage-list "A10-Build_Docs, A10-PackageSanityCheck-PY310-UB2204, A100X-PackageSanityCheck-PY312-UB2404, A30-AutoDeploy-1, H100_PCIe-AutoDeploy-1, DGX_B200-AutoDeploy-1, A100X-PyTorch-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1, DGX_B200-8_GPUs-AutoDeploy-Post-Merge-1" --disable-fail-fast |
|
PR_Github #46917 [ run ] triggered by Bot. Commit: |
|
PR_Github #46908 [ run ] completed with state |
|
PR_Github #46917 [ run ] completed with state
|
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
f27c805 to 81ab8d4
|
/bot run |
|
PR_Github #48006 [ run ] triggered by Bot. Commit: |
|
PR_Github #48006 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #48044 [ run ] triggered by Bot. Commit: |
|
PR_Github #48044 [ run ] completed with state
|
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #48424 [ run ] triggered by Bot. Commit: |
|
PR_Github #48424 [ run ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #48429 [ run ] triggered by Bot. Commit: |
|
PR_Github #48429 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #48619 [ run ] triggered by Bot. Commit: |
|
PR_Github #48619 [ run ] completed with state |
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #48997 [ run ] triggered by Bot. Commit: |
|
PR_Github #48997 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #49247 [ run ] triggered by Bot. Commit: |
|
PR_Github #49247 [ run ] completed with state |
Summary by CodeRabbit
Release Notes
New Features
Improvements
Documentation
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions)
- Any new dependencies have been scanned for license and vulnerabilities
- CODEOWNERS updated if ownership changes
- Documentation updated as needed
- Update tava architecture diagram if there is a significant design change in PR.
- The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.