
Conversation

@cjluo-nv
Collaborator

@cjluo-nv cjluo-nv commented Jan 16, 2026

What does this PR do?

Overview: Unified the FP8 and NVFP4 KV cache scaling factor definition so that the same checkpoint can be used for both FP8 and NVFP4 KV cache quantization deployments.
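For illustration, here is a minimal sketch of what defining the scale as amax / 448 means at deployment time for an FP8 (E4M3) KV cache, where 448 is the E4M3 finite maximum; the tensor and variable names below are hypothetical and not taken from the ModelOpt sources:

import torch

# Stand-in for a KV cache tensor observed during calibration.
x = torch.randn(4, 8) * 10.0

# The unified scaling factor: amax / 448, for both FP8 and NVFP4 KV caches.
scale = x.abs().amax() / 448.0

# FP8 deployment path: quantize with the shared scale, then dequantize.
q = (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
x_dequant = q.to(torch.float32) * scale  # recovers x up to FP8 rounding error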

Testing

Unit test

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

Release Notes

  • Refactor

    • Fixed the KV cache maximum bound at 448 for FP8 and NVFP4 quantization, simplifying the configuration logic.
  • Chores

    • Removed internal constants from public exports.


Unified the FP8 and NVFP4 kv cache scaling factor definition so the same checkpoint can be used for both FP8 and NVFP4 kv cache quantization deployment

Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
@cjluo-nv cjluo-nv requested a review from a team as a code owner January 16, 2026 05:47
@cjluo-nv cjluo-nv requested a review from meenchen January 16, 2026 05:47
@coderabbitai
Contributor

coderabbitai bot commented Jan 16, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


📝 Walkthrough

Removed three constant imports and replaced the dynamic kv_cache_max_bound calculation with a hard-coded value of 448 in the export function, accompanied by a clarifying comment about FP8 and NVFP4 KV cache quantization.

Changes

Cohort: KV Cache Configuration Simplification
File(s): modelopt/torch/export/unified_export_hf.py
Summary: Removed imports of the KV_CACHE_FP8, KV_CACHE_NVFP4, and KV_CACHE_NVFP4_AFFINE constants. Eliminated the dynamic kv_cache_max_bound calculation based on the quantization algorithm and mapping; the value is now hard-coded to 448 with a clarifying comment.
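
As a hedged before/after sketch of this simplification (the removed mapping is paraphrased from the summary above, not copied from the actual diff, and bound_for is a placeholder name):

# Before (paraphrased): the bound was derived from the KV cache quantization
# algorithm via the now-removed constants.
#   kv_cache_max_bound = bound_for(kv_cache_quant_algo)  # KV_CACHE_FP8 / KV_CACHE_NVFP4 / KV_CACHE_NVFP4_AFFINE

# After: a single hard-coded bound, shared by FP8 and NVFP4 KV cache
# quantization, so the scale is always amax / 448.
kv_cache_max_bound = 448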

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks (3 passed)

  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'Define kv cache scaling factor as amax / 448' directly describes the main change of hard-coding kv_cache_max_bound to 448 and unifying the scaling factor for both FP8 and NVFP4 quantization.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@cjluo-nv cjluo-nv requested a review from Edwardf0t1 January 16, 2026 05:48
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `modelopt/torch/export/unified_export_hf.py`:
  • Around lines 656-660: In _export_transformers_checkpoint, the variable kv_cache_format is undefined when postprocess_state_dict is called. Extract it from quant_config (the dict fetched earlier) via quant_config["quantization"]["kv_cache_quant_algo"] and pass that value into postprocess_state_dict, defining kv_cache_format where quantized_state_dict and kv_cache_max_bound are set, to avoid the NameError and ensure postprocess_state_dict receives the correct format.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e6e4efd and 6019a6e.

📒 Files selected for processing (1)
  • modelopt/torch/export/unified_export_hf.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs


Comment on lines 656 to 660
# We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization.
kv_cache_max_bound = 448
quantized_state_dict = postprocess_state_dict(
    quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's read the function containing lines 656-660 to understand the context
cat -n modelopt/torch/export/unified_export_hf.py | sed -n '630,680p'


🏁 Script executed:

# Check if kv_cache_format is defined anywhere in the function before line 659
rg -n 'kv_cache_format\s*=' modelopt/torch/export/unified_export_hf.py


🏁 Script executed:

# Find the postprocess_state_dict function definition to see its signature
rg -n 'def postprocess_state_dict' -A 5 --type py


🏁 Script executed:

# Search for all usages of kv_cache_format in the codebase
rg -n 'kv_cache_format' --type py -C 2


🏁 Script executed:

# Check how get_quant_config returns data and what it contains regarding kv_cache
rg -n 'quant_config\[' modelopt/torch/export/unified_export_hf.py


🏁 Script executed:

# Look at the structure returned by get_quant_config
cat -n modelopt/torch/export/quant_utils.py | sed -n '1220,1290p'


Critical: kv_cache_format is undefined and will cause a NameError at runtime.

The variable kv_cache_format is passed to postprocess_state_dict at line 659, but it is never defined in the _export_transformers_checkpoint function. The value should be extracted from quant_config, which is obtained at line 645 and contains the kv cache quantization format in quant_config["quantization"]["kv_cache_quant_algo"].

🔧 Suggested fix
     # We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization.
     kv_cache_max_bound = 448
+    kv_cache_format = quant_config["quantization"].get("kv_cache_quant_algo")
     quantized_state_dict = postprocess_state_dict(
         quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
     )
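
For context, a minimal sketch of the quant_config shape the suggested fix relies on; the nesting follows the review comment above, and the concrete values are illustrative assumptions:

# Illustrative shape only; the real dict comes from get_quant_config in quant_utils.py.
quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",         # hypothetical example value
        "kv_cache_quant_algo": "FP8",  # may be absent when the KV cache is unquantized
    }
}

# .get() returns None instead of raising a KeyError when the key is missing.
kv_cache_format = quant_config["quantization"].get("kv_cache_quant_algo")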

@codecov

codecov bot commented Jan 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.20%. Comparing base (1cc8e6b) to head (33b6dcb).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #790   +/-   ##
=======================================
  Coverage   74.20%   74.20%           
=======================================
  Files         192      192           
  Lines       19238    19238           
=======================================
  Hits        14276    14276           
  Misses       4962     4962           


Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv enabled auto-merge (squash) January 16, 2026 17:29
@cjluo-nv cjluo-nv merged commit b0e7d9f into main Jan 20, 2026
36 checks passed
@cjluo-nv cjluo-nv deleted the cjluo-nv-patch-3 branch January 20, 2026 08:34