[None][feat] Add a flag in trtllm serve to support overriding kv cache dtype #11487

cjluo-nv merged 7 commits into NVIDIA:main
Conversation
Force-pushed 23973a7 to 6978b89
📝 Walkthrough

This pull request adds comprehensive support for "nvfp4" as a KV cache quantization dtype. Changes include a new CLI parameter.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant CLI as serve CLI
    participant GetArgs as get_llm_args()
    participant KvCfg as KvCacheConfig
    participant TorchArgs as TorchLlmArgs
    participant ModelLdr as ModelLoader
    participant QuantCfg as QuantConfig
    participant LinearMod as Linear Module
    User->>CLI: --kv_cache_dtype nvfp4
    CLI->>GetArgs: kv_cache_dtype="nvfp4"
    GetArgs->>KvCfg: dtype="nvfp4"
    KvCfg->>KvCfg: validate_dtype() normalizes to "nvfp4"
    KvCfg-->>GetArgs: validated config
    GetArgs->>TorchArgs: KvCacheConfig(dtype="nvfp4")
    TorchArgs->>TorchArgs: sync_quant_config_with_kv_cache_config_dtype()
    rect rgba(100, 150, 200, 0.5)
        Note over TorchArgs: Maps dtype="nvfp4"<br/>to QuantAlgo.NVFP4
    end
    TorchArgs->>QuantCfg: kv_cache_quant_algo=NVFP4
    TorchArgs-->>GetArgs: synced args
    GetArgs-->>CLI: initialized llm_args
    CLI->>ModelLdr: validate_and_set_kv_cache_quant()
    ModelLdr->>QuantCfg: check for override
    rect rgba(100, 200, 150, 0.5)
        Note over ModelLdr: If explicit override differs,<br/>log warning & apply override
    end
    ModelLdr->>LinearMod: initialize kv_scales
    LinearMod->>LinearMod: load_weights_fused_qkv_linear()
    rect rgba(200, 150, 100, 0.5)
        Note over LinearMod: Populate kv_scales &<br/>inv_kv_scales from metadata
    end
    LinearMod-->>ModelLdr: initialized module
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/llmapi/llm_utils.py (1)
361-386: ⚠️ Potential issue | 🟠 Major

Apply explicit kv_cache override when HF config has `kv_cache_quant_algo: null`.

When an HF quant config exists but its `kv_cache_quant_algo` field is `null` (line 363 evaluates to `False`), the else block (line 380) only validates `quant_config.kv_cache_quant_algo` but never applies the user's `explicit_kv_cache_quant_algo` from `kv_cache_dtype`. This causes the explicit override to be silently dropped.

The PyTorch flow compensates through `validate_and_set_kv_cache_quant` in `model_loader.py` called downstream, but any code path bypassing that compensation will lose the explicit setting.

Fix

```diff
         else:
+            if explicit_kv_cache_quant_algo is not None:
+                logger.info(
+                    f"Setting kv_cache_quant_algo={explicit_kv_cache_quant_algo} from explicit kv_cache_config.dtype={kv_cache_dtype}."
+                )
+                quant_config.kv_cache_quant_algo = explicit_kv_cache_quant_algo
-            if quant_config.kv_cache_quant_algo not in [
+            elif quant_config.kv_cache_quant_algo not in [
                     None, QuantAlgo.FP8, QuantAlgo.NVFP4
             ]:
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_utils.py` around lines 361-386: When HF quant config exists but hf_kv_cache_quant_algo is None, the code currently only validates quant_config.kv_cache_quant_algo and drops any explicit_kv_cache_quant_algo; change the else branch so that if explicit_kv_cache_quant_algo is provided you assign it into quant_config.kv_cache_quant_algo (i.e., set quant_config.kv_cache_quant_algo = explicit_kv_cache_quant_algo) before running any validation, otherwise keep the existing validation that only allows None/QuantAlgo.FP8/QuantAlgo.NVFP4; reference hf_kv_cache_quant_algo, explicit_kv_cache_quant_algo, and quant_config.kv_cache_quant_algo when making the update.

tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
28-33: ⚠️ Potential issue | 🟡 Minor

Store `QuantAlgo` enum members instead of `.value` strings in `_KV_CACHE_MAP` to match type hints.

The field `kv_cache_quant_algo` is typed as `Optional[QuantAlgo]`, but `_KV_CACHE_MAP` stores `QuantAlgo.FP8.value` and `QuantAlgo.NVFP4.value` (plain strings). This causes the code to set enum-typed fields to strings, violating the type contract. While equality checks work due to `StrEnum`, the assignment should use enum members to be type-correct and match the rest of the codebase.

Suggested fix

```diff
 _KV_CACHE_MAP = {
-    "fp8": QuantAlgo.FP8.value,
-    "nvfp4": QuantAlgo.NVFP4.value,
+    "fp8": QuantAlgo.FP8,
+    "nvfp4": QuantAlgo.NVFP4,
     "auto": "auto"
 }
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 28-33: The _KV_CACHE_MAP currently maps to QuantAlgo.FP8.value and QuantAlgo.NVFP4.value (strings) which mismatches the kv_cache_quant_algo Optional[QuantAlgo] typing; change the map to use the enum members (QuantAlgo.FP8, QuantAlgo.NVFP4, and keep "auto" as appropriate) so assignments to kv_cache_quant_algo use actual QuantAlgo members; update any code that validates against _VALID_KV_CACHE_DTYPES if necessary to compare .value or the enum itself (refer to symbols _KV_CACHE_MAP, QuantAlgo, _VALID_KV_CACHE_DTYPES, and kv_cache_quant_algo to locate and fix the mappings and comparisons).
🧹 Nitpick comments (1)
tensorrt_llm/llmapi/llm_args.py (1)
3297-3313: Dead `else` branch — `logger.warning` is now unreachable

`validate_dtype` (added in this PR) constrains `kv_cache_config.dtype` to exactly {"auto", "fp8", "nvfp4"}. All three values are now handled by explicit if/elif branches, so the `else: logger.warning(...)` at line 3310 can never execute. This erodes the warning's intended safety-net role (alerting on an unhandled dtype) because it will silently never fire even when a future developer adds a new dtype to the validator but forgets to add a handler here.

Consider either removing the dead branch or turning it into an `assert False`/`raise` to catch future omissions at development time:

♻️ Proposed fix

```diff
         elif self.kv_cache_config.dtype == 'nvfp4':
             self.quant_config.kv_cache_quant_algo = QuantAlgo.NVFP4
-        else:
-            logger.warning(
-                f"Cannot sync quant_config.kv_cache_quant_algo with kv_cache_config.dtype of {self.kv_cache_config.dtype}, "
-                "please update the validator")
+        else:
+            # This branch is unreachable: validate_dtype restricts dtype to
+            # {"auto", "fp8", "nvfp4"}, all of which are handled above.
+            # If a new dtype is added to validate_dtype, add a handler here too.
+            raise AssertionError(
+                f"Unhandled kv_cache_config.dtype '{self.kv_cache_config.dtype}' in "
+                "sync_quant_config_with_kv_cache_config_dtype — update this method.")
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 3297-3313: The else branch in sync_quant_config_with_kv_cache_config_dtype is dead because validate_dtype restricts kv_cache_config.dtype to {"auto","fp8","nvfp4"}; replace the unreachable logger.warning with a fail-fast check (e.g., raise ValueError or assert False) so future additions to kv_cache_config.dtype break immediately if not handled; update the function to raise an explicit error mentioning the unexpected kv_cache_config.dtype and include context like the field name and current value instead of the logger.warning; keep existing handling of 'auto', 'fp8' (set QuantAlgo.FP8) and 'nvfp4' (set QuantAlgo.NVFP4).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 510-526: The KV-scales loading block (references:
module.kv_scales, module.inv_kv_scales, weights, copy_weight) must be gated by
the TRTLLM_LOAD_KV_SCALES env var like the other methods; wrap the existing if
hasattr(module, "kv_scales") { ... } block with a condition that checks
os.environ.get("TRTLLM_LOAD_KV_SCALES", "1") == "1" before computing
k_scales/v_scales, calling copy_weight, and setting module.inv_kv_scales.data so
KV scales can be disabled via the environment variable.
In `@tests/unittest/llmapi/test_kv_cache_dtype_override.py`:
- Around line 33-35: The test uses pytest.raises(match="kv_cache_config.dtype
must be one of") which treats the dots as regex wildcards; update the assertion
in the test_kv_cache_dtype_override.py test so the literal string is
matched—either escape the dots in the pattern (e.g., "kv_cache_config\.dtype
must be one of") or wrap the expected message with re.escape before passing to
pytest.raises(match=...), keeping the assertion on
KvCacheConfig(dtype="float16") unchanged.
JunyiXu-nv left a comment:
Changes in trtllm-serve LGTM. Left 1 comment.
Is it possible to add a unit test to guard this feature? You can reuse the existing CI model in
Force-pushed 6978b89 to cb2cbb2
Unit test added, more on the trtllm-serve side.
Done.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Force-pushed 629816a to 5e7d170
/bot run
PR_Github #37046 [ run ] triggered by Bot.
PR_Github #37046 [ run ] completed.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>

/bot run
PR_Github #37113 [ run ] triggered by Bot.
PR_Github #37113 [ run ] completed.

/bot run
PR_Github #37358 [ run ] triggered by Bot.
PR_Github #37358 [ run ] completed.

/bot run
PR_Github #37374 [ run ] triggered by Bot.
PR_Github #37374 [ run ] completed.

/bot run
PR_Github #37395 [ run ] triggered by Bot.
PR_Github #37395 [ run ] completed.

/bot run
PR_Github #37472 [ run ] triggered by Bot.
PR_Github #37472 [ run ] completed.

/bot run
PR_Github #37535 [ run ] triggered by Bot.
PR_Github #37535 [ run ] completed.
…e dtype (NVIDIA#11487) Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Summary by CodeRabbit

Release Notes

New Features
- `--kv_cache_dtype` CLI parameter with options: auto, fp8, nvfp4

Bug Fixes

Tests

Documentation
Description
Both FP8 and NVFP4 KV cache scales are now unified as amax/448, so the user can ask trtllm-serve to pick a preferred KV cache config for deployment with the same checkpoint. This PR adds the `--kv_cache_dtype nvfp4` flag to trtllm-serve to force it to use an NVFP4 KV cache instead.
Test Coverage
Unit test. Also manually loaded a checkpoint, forced NVFP4 KV cache using trtllm-serve, and inspected the output.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] {run, kill, skip, reuse-pipeline} ...`

Provide a user friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.
Details
run

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with pull request.

skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.