
Added support for KV cache quantization for vLLM fakequant #686

Merged
kinjalpatel27 merged 6 commits into main from kinjal/vllm_att_quant on Dec 16, 2025

Conversation

kinjalpatel27 (Contributor) commented on Dec 13, 2025

What does this PR do?

Type of change: New feature

Overview:

  • Added support for quantizing the KV cache in vLLM fakequant by extending quantization support to the Attention module (see the sketch after this list)
  • Modified parallel-state initialization to incorporate vLLM's parallel state groups so quantization parameters are synced correctly across ranks
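
A minimal sketch of how the two pieces could fit together, assuming the ModelOpt API (mtq.quantize, NVFP4_DEFAULT_CFG) and that NVFP4_KV_CFG follows the same config-dict layout; the actual wiring in vllm_serve_fakequant.py may differ:

import copy

import modelopt.torch.quantization as mtq

def quantize_with_kv_cache(model, calib_forward_loop):
    # Merge the KV cache config into the base weight/activation config so the
    # attention quantizers (the k_bmm_quantizer / v_bmm_quantizer shown in the
    # test dump below) are enabled alongside the linear-layer quantizers.
    cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
    cfg["quant_cfg"].update(mtq.NVFP4_KV_CFG["quant_cfg"])

    # Calibration runs under vLLM's parallel state groups, so amax values are
    # synced across ranks and every worker holds identical quant parameters.
    return mtq.quantize(model, cfg, forward_loop=calib_forward_loop)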

Usage

Please refer to the README.

KV_QUANT_CFG=NVFP4_KV_CFG QUANT_CFG=NVFP4_DEFAULT_CFG python vllm_serve_fakequant.py meta-llama/Llama-3.2-1B-Instruct --served-model-name meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0 --port 8001 --trust-remote-code 
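
Once the server is up, it can be queried like any OpenAI-compatible vLLM endpoint (standard vLLM serving behavior, not specific to this PR; the model name must match --served-model-name):

from openai import OpenAI

# Point the client at the fakequant server started above.
client = OpenAI(base_url="http://0.0.0.0:8001/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)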

Testing

Locally tested KV Cache quantization

model.layers.0.self_attn.qkv_proj.input_quantizer    TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=5.0312 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.qkv_proj.weight_quantizer   TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.6758 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.qkv_proj.output_quantizer   TensorQuantizer(disabled)
model.layers.0.self_attn.o_proj.input_quantizer      TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.3438 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.o_proj.weight_quantizer     TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.3145 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.o_proj.output_quantizer     TensorQuantizer(disabled)
model.layers.0.self_attn.attn.q_bmm_quantizer        TensorQuantizer(disabled)
model.layers.0.self_attn.attn.k_bmm_quantizer        TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=13.8125 calibrator=MaxCalibrator quant)
model.layers.0.self_attn.attn.v_bmm_quantizer        TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=1.3438 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.input_quantizer      TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=3.2812 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.weight_quantizer     TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.5938 calibrator=MaxCalibrator quant)
model.layers.0.mlp.gate_up_proj.output_quantizer     TensorQuantizer(disabled)
model.layers.0.mlp.down_proj.input_quantizer         TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=33.7500 calibrator=MaxCalibrator quant)
model.layers.0.mlp.down_proj.weight_quantizer        TensorQuantizer((2, 1) bit fake block_sizes={-1: 16, 'type': 'dynamic', 'scale_bits': (4, 3)}, amax=0.6211 calibrator=MaxCalibrator quant)
model.layers.0.mlp.down_proj.output_quantizer        TensorQuantizer(disabled)
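
For reference, a per-module listing like the one above can be produced with ModelOpt's summary helper (assuming the standard mtq API; the exact call used for this test is not shown in the PR):

import modelopt.torch.quantization as mtq

# Walks the model and prints every TensorQuantizer with its bit format,
# block sizes, and calibrated amax, as in the dump above.
mtq.print_quant_summary(model)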

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: NA
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

@kinjalpatel27 kinjalpatel27 requested review from a team as code owners December 13, 2025 02:02

codecov bot commented on Dec 13, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.72%. Comparing base (10d0fec) to head (88c4452).
⚠️ Report is 8 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #686   +/-   ##
=======================================
  Coverage   74.72%   74.72%           
=======================================
  Files         192      192           
  Lines       18828    18828           
=======================================
  Hits        14070    14070           
  Misses       4758     4758           


Review thread on the new parallel-state initialization code:

try:
    dp_group = get_dp_group().device_group
except Exception:
    dp_group = -1
cjluo-nv (Collaborator) commented:

Do we want to throw an error here?

kinjalpatel27 (Contributor, Author) commented on Dec 15, 2025:

Thanks @cjluo-nv. I removed the try-except in f103458, so any error will now surface instead of being swallowed.
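
For context, a sketch of what the fix plausibly looks like after f103458 (the exact diff is not shown here):

from vllm.distributed import get_dp_group  # assumed import path

# The try/except fallback to -1 is gone: if the data-parallel group is
# unavailable, initialization now raises instead of failing silently.
dp_group = get_dp_group().device_group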

@kinjalpatel27 kinjalpatel27 merged commit 7233616 into main Dec 16, 2025
36 checks passed
@kinjalpatel27 kinjalpatel27 deleted the kinjal/vllm_att_quant branch December 16, 2025 17:36
@kinjalpatel27 kinjalpatel27 restored the kinjal/vllm_att_quant branch January 1, 2026 01:46
@kinjalpatel27 kinjalpatel27 deleted the kinjal/vllm_att_quant branch January 1, 2026 01:48