DeepSeek-V3.2-Exp NVFP4 quantized model fails with flash_mla KV cache dtype incompatibility (TensorRT-LLM 1.2.0rc7) #763

@evgeniiperepelkin

Description

After successfully quantizing DeepSeek-V3.2-Exp to NVFP4 using Model-Optimizer, the model fails to run inference with trtllm-serve. The flash_mla kernel requires KV cache in BFloat16, but kv_cache_config.dtype only accepts fp8, nvfp4, or auto.
Environment

GPU: 8x NVIDIA H200 141GB (SM 9.0, Hopper)
CUDA: 12.8
Driver: 570.x
TensorRT-LLM: 1.2.0rc7
Model-Optimizer: Latest (pip install)
Base Model: deepseek-ai/DeepSeek-V3.2-Exp
Quantized Format: NVFP4
Docker Image: nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc7

Steps to Reproduce
Step 1: Quantization (Completed Successfully)

export HF_FP8_CKPT=/data/models/DeepSeek-V3.2-Exp
export DS_CKPT=/data/checkpoints/deepseek-converted
export FP4_QUANT_PATH=/data/checkpoints/deepseek-ptq
export HF_FP4_PATH=/data/models/DeepSeek-V3.2-NVFP4
huggingface-cli download deepseek-ai/DeepSeek-V3.2-Exp --local-dir $HF_FP8_CKPT
git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git
cd DeepSeek-V3.2-Exp && git checkout 87e509a
pip install nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com/
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
pip install -r inference/requirements.txt
python inference/convert.py --hf-ckpt-path $HF_FP8_CKPT --save-path $DS_CKPT --n-experts 256 --model-parallel 8
torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
./quantize_fp8_to_nvfp4.sh --amax_path $FP4_QUANT_PATH --fp4_output_path $HF_FP4_PATH --fp8_hf_path $HF_FP8_CKPT --world_size 8

Result: Quantization completed successfully. Output: ~387GB model (163 safetensors files).
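For reference, a quick sanity check of the exported checkpoint (a minimal sketch, not part of the Model-Optimizer flow; the directory and expected shard count come from the run above):

import glob, os

# Hypothetical post-quantization sanity check of the NVFP4 output directory.
fp4_path = "/data/models/DeepSeek-V3.2-NVFP4"
shards = glob.glob(os.path.join(fp4_path, "*.safetensors"))
total_gb = sum(os.path.getsize(f) for f in shards) / 1024**3
print(f"{len(shards)} safetensors shards, ~{total_gb:.0f} GB total")
# Expected from the run above: 163 shards, roughly 387 GB.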
Step 2: Inference Attempt
Config file /data/extra-llm-api-config.yml:

kv_cache_config:
  enable_block_reuse: false
  dtype: auto

Run command:

docker run -d --name deepseek-serve --gpus all --ipc=host \
  --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm-size=32g \
  --ulimit memlock=-1 --ulimit stack=67108864 \
  -v /data:/data -p 12345:12345 \
  nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc7 \
  trtllm-serve /data/models/DeepSeek-V3.2-NVFP4 \
    --backend pytorch \
    --max_batch_size 32 \
    --max_num_tokens 120000 \
    --tp_size 8 --ep_size 8 --pp_size 1 \
    --kv_cache_free_gpu_memory_fraction 0.35 \
    --extra_llm_api_options /data/extra-llm-api-config.yml \
    --host 0.0.0.0 --port 12345
Error

Model loads successfully, but crashes with:
[TRT-LLM] [RANK 0] [E] Failed to initialize executor on rank 0:
RuntimeError: Expected kv.dtype() == torch::kBFloat16 to be true, but got false
The error originates from the flash_mla_cuda.sparse_prefill_fwd kernel.
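The message format suggests a C++-side dtype check inside the kernel binding. A hypothetical Python-level equivalent of that guard (illustrative only, not the actual flash_mla source):

import torch

def sparse_prefill_kv_guard(kv: torch.Tensor) -> None:
    # Hypothetical stand-in for the kernel's dtype check: the sparse prefill
    # path only accepts a BFloat16 KV tensor, so an fp8/nvfp4-quantized KV
    # cache fails here before the kernel ever launches.
    if kv.dtype != torch.bfloat16:
        raise RuntimeError(
            "Expected kv.dtype() == torch::kBFloat16 to be true, but got false"
        )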

Root Cause
The flash_mla kernel (Multi-head Latent Attention for DeepSeek) requires the KV cache in BFloat16, but:
kv_cache_config.dtype only accepts: fp8, nvfp4, auto
Setting dtype: bfloat16 results in: ValueError: Accepted types are "('fp8', 'nvfp4', 'auto')"
auto does not resolve to BFloat16
This creates an impossible configuration: the kernel's requirement cannot be satisfied through the available API (see the sketch below).
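A minimal sketch of the API-level dead end, assuming the llmapi KvCacheConfig class mirrors the YAML fields used above (the import path and constructor signature are assumptions here):

from tensorrt_llm.llmapi import KvCacheConfig  # assumed import path

# "auto" passes config validation, but flash_mla still crashes at runtime (see Error above).
KvCacheConfig(enable_block_reuse=False, dtype="auto")

# "bfloat16", the dtype the kernel actually needs, is rejected at validation time.
try:
    KvCacheConfig(enable_block_reuse=False, dtype="bfloat16")
except ValueError as err:
    print(err)  # Accepted types are "('fp8', 'nvfp4', 'auto')"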
Additional Warnings
You are using a model of type deepseek_v32 to instantiate a model of type deepseek_v3. This is not supported for all configurations of models and can yield errors.
Expected Behavior
The NVFP4 quantized DeepSeek-V3.2-Exp model should be deployable with trtllm-serve using PyTorch backend.

Questions:

  1. Is DeepSeek-V3.2-Exp officially supported for NVFP4 quantization?
  2. Can bfloat16 be added as an accepted value for kv_cache_config.dtype?
  3. Which TensorRT-LLM version is recommended for DeepSeek-V3.2 NVFP4 inference?
