
Fix(models/siglip): Add compatibility for Gemma models quantized by llm-compressor #19643


Merged
merged 3 commits into vllm-project:main from fix-gemma3-keyerror on Jun 23, 2025

Conversation

Contributor

@Flink-ddd Flink-ddd commented Jun 14, 2025

Description

This PR resolves a KeyError that occurs when attempting to serve a Gemma-3 model that has been quantized by the vllm-project/llm-compressor library. It makes the SigLIP vision model loader more robust to variations in weight naming conventions.

This directly addresses the root cause of the bug reported in vllm-project/llm-compressor#1546.

Root Cause

The KeyError arises from a mismatch in weight-naming conventions between vLLM's internal SiglipVisionModel architecture and the standard model artifacts produced by tools like llm-compressor.

Specifically, the SiglipVisionModel in vLLM prepends an additional vision_model. prefix to its parameters (e.g., expecting vision_model.vision_model.encoder...), while the quantized model artifact uses the standard Hugging Face naming convention (e.g., vision_model.encoder...).

The existing load_weights function in siglip.py strictly expected the internal vLLM format, leading to a KeyError when it encountered a weight name from the quantized model.
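
For illustration, the mismatch looks like this (the layer path below is hypothetical and only serves to show the two prefixes):

# Hypothetical weight name, used only to illustrate the prefix mismatch.
checkpoint_name = "vision_model.encoder.layers.0.self_attn.q_proj.weight"  # Hugging Face convention
vllm_internal_name = "vision_model." + checkpoint_name  # what vLLM's SiglipVisionModel registers
assert vllm_internal_name.startswith("vision_model.vision_model.")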

Solution

This patch introduces a flexible remapping logic within the load_weights function in vllm/model_executor/models/siglip.py.

  • Before attempting to access a parameter, the code now checks if the incoming weight name from the file exists in the model's parameter dictionary.
  • If it doesn't exist, it programmatically attempts to fix the common prefix mismatch by prepending vision_model..
  • If the remapped name exists, it uses the new name to load the weight and prints an informational log.
  • If the remapped name still doesn't exist, it prints a warning and skips the weight, preventing the server from crashing.

This makes the loader robust to this naming discrepancy, resolving the KeyError.
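
A minimal, self-contained sketch of this fallback is shown below. It is not the exact vLLM code: params_dict and the plain tensor copy at the end stand in for vLLM's per-parameter weight_loader machinery.

import logging
from typing import Iterable

import torch

logger = logging.getLogger(__name__)


def load_with_prefix_fallback(
    params_dict: dict[str, torch.nn.Parameter],
    weights: Iterable[tuple[str, torch.Tensor]],
) -> None:
    """Load weights, retrying with a 'vision_model.' prefix when a name is missing."""
    for name, loaded_weight in weights:
        if name not in params_dict:
            potential_name = f"vision_model.{name}"
            if potential_name in params_dict:
                logger.info("Remapping weight %s -> %s", name, potential_name)
                name = potential_name
            else:
                logger.warning("Skipping weight %s: no matching parameter", name)
                continue
        # Simplified copy; the real loader dispatches to each param's weight_loader.
        params_dict[name].data.copy_(loaded_weight)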

Testing

The fix was verified in a cloud GPU environment (Colab A100) using the following methodology:

  1. A quantized Gemma-3 model artifact was prepared, which was confirmed to trigger the KeyError with the original vLLM code.
  2. The official pre-compiled vllm package was installed.
  3. The siglip.py file within the installation was "hot-patched" with the code from this PR.
  4. The vllm serve command was re-run against the quantized model.

Result:

  • The custom INFO: Remapping weight... logs were observed, confirming the patch was being correctly executed.
  • The KeyError was successfully resolved.
  • The vllm server launched successfully and was able to serve the quantized model.

Thank you for considering this contribution!

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@@ -516,6 +516,13 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                if name not in params_dict:
+                    potential_name = f"vision_model.{name}"
Collaborator


Can we add some comment to explain this change?

Contributor Author


Hi @houseroad, thank you so much for the quick and insightful review!

  1. Absolutely, that's a great suggestion. I've added a detailed comment to the code to explain the purpose of the remapping logic.
  2. You've pointed out the core issue perfectly regarding the double prefix. My current fix in load_weights is indeed a workaround at the loading stage. A more fundamental fix at the model's __init__ level would certainly be cleaner. I'm happy to explore that path. Could you provide any pointers on the preferred way to handle this prefixing logic within the vLLM architecture? Or would you prefer to merge the current loader-side fix first to resolve the user-facing bug, and then create a separate issue for the architectural refactoring?

I'll push the commit with the added comments shortly. Thanks again!

Collaborator

@houseroad houseroad left a comment


Wondering if we can fix the weight-loading logic to avoid the vision_model.vision_model double prefix?

@Flink-ddd Flink-ddd force-pushed the fix-gemma3-keyerror branch from bec5429 to 5e0cded Compare June 18, 2025 02:34
@Flink-ddd
Contributor Author

Hi @houseroad , thank you so much for the quick and insightful review!

I changed my approach after looking at how other models are implemented, and I've just finished verifying the new method — it passes successfully.

!python examples/offline_inference/vision_language.py -m gemma3

output:

2025-06-19 12:17:05.017497: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-19 12:17:05.035118: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1750335425.056146    4463 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750335425.062547    4463 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-19 12:17:05.083682: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO 06-19 12:17:07 [__init__.py:244] Automatically detected platform cuda.
config.json: 100% 855/855 [00:00<00:00, 5.86MB/s]
preprocessor_config.json: 100% 570/570 [00:00<00:00, 4.31MB/s]
INFO 06-19 12:17:27 [config.py:831] This model supports multiple tasks: {'generate', 'classify', 'score', 'reward', 'embed'}. Defaulting to 'generate'.
tokenizer_config.json: 100% 1.16M/1.16M [00:00<00:00, 1.36MB/s]
INFO 06-19 12:17:28 [config.py:1444] Using max model len 2048
INFO 06-19 12:17:28 [config.py:2197] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 06-19 12:17:28 [config.py:2234] max_num_batched_tokens (8192) exceeds max_num_seqs* max_model_len (4096). This may lead to unexpected behavior.
tokenizer.model: 100% 4.69M/4.69M [00:00<00:00, 215MB/s]
tokenizer.json: 100% 33.4M/33.4M [00:00<00:00, 335MB/s]
added_tokens.json: 100% 35.0/35.0 [00:00<00:00, 316kB/s]
special_tokens_map.json: 100% 662/662 [00:00<00:00, 5.43MB/s]
generation_config.json: 100% 215/215 [00:00<00:00, 1.59MB/s]
INFO 06-19 12:17:35 [core.py:460] Waiting for init message from front-end.
INFO 06-19 12:17:35 [core.py:70] Initializing a V1 LLM engine (v0.1.dev7183+g727cb28) with config: model='google/gemma-3-4b-it', speculative_config=None, tokenizer='google/gemma-3-4b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=google/gemma-3-4b-it, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 06-19 12:17:36 [utils.py:2756] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0xe44187ca6d0>
INFO 06-19 12:17:36 [parallel_state.py:1072] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
processor_config.json: 100% 70.0/70.0 [00:00<00:00, 550kB/s]
chat_template.json: 100% 1.61k/1.61k [00:00<00:00, 13.2MB/s]
WARNING 06-19 12:17:47 [gemma3_mm.py:124] `do_pan_and_scan=True` has suboptimal results on V1 because of the simplified attention pattern being used.
WARNING 06-19 12:17:47 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 06-19 12:17:47 [gpu_model_runner.py:1627] Starting to load model google/gemma-3-4b-it...
INFO 06-19 12:17:48 [gpu_model_runner.py:1632] Loading model from scratch...
INFO 06-19 12:17:48 [cuda.py:259] Using Flash Attention backend on V1 engine.
INFO 06-19 12:17:49 [weight_utils.py:292] Using model weights format ['*.safetensors']
model-00001-of-00002.safetensors: 100% 4.96G/4.96G [00:16<00:00, 307MB/s]
model-00002-of-00002.safetensors: 100% 3.64G/3.64G [00:37<00:00, 97.2MB/s]
INFO 06-19 12:18:43 [weight_utils.py:308] Time spent downloading weights for google/gemma-3-4b-it: 54.376797 seconds
model.safetensors.index.json: 100% 90.6k/90.6k [00:00<00:00, 21.5MB/s]
Loading safetensors checkpoint shards: 100% 2/2 [00:02<00:00,  1.25s/it]
INFO 06-19 12:18:46 [default_loader.py:272] Loading weights took 2.62 seconds
INFO 06-19 12:18:47 [gpu_model_runner.py:1656] Model loading took 8.5834 GiB and 58.401776 seconds
INFO 06-19 12:18:48 [gpu_model_runner.py:2011] Encoder cache will be initialized with a budget of 8192 tokens, and profiled with 2 image items of the maximum feature size.
INFO 06-19 12:19:02 [backends.py:508] Using cache directory: /root/.cache/vllm/torch_compile_cache/98cd808517/rank_0_0/backbone for vLLM's torch.compile
INFO 06-19 12:19:02 [backends.py:519] Dynamo bytecode transform time: 13.46 s
[rank0]:W0619 12:19:04.701000 4710 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode
INFO 06-19 12:19:08 [backends.py:181] Cache the graph of shape None for later use
INFO 06-19 12:19:57 [backends.py:193] Compiling a graph for general shape takes 54.23 s
INFO 06-19 12:20:46 [monitor.py:34] torch.compile takes 67.70 s in total
INFO 06-19 12:20:47 [gpu_worker.py:232] Available KV cache memory: 10.63 GiB
WARNING 06-19 12:20:48 [kv_cache_utils.py:831] Add 1 padding layers, may waste at most 3.45% KV cache memory
INFO 06-19 12:20:48 [kv_cache_utils.py:871] GPU KV cache size: 79,584 tokens
INFO 06-19 12:20:48 [kv_cache_utils.py:875] Maximum concurrency for 2,048 tokens per request: 38.61x
WARNING 06-19 12:20:48 [utils.py:101] Unable to detect current VLLM config. Defaulting to NHD kv cache layout.
Capturing CUDA graphs: 100% 67/67 [00:42<00:00,  1.56it/s]
INFO 06-19 12:21:31 [gpu_model_runner.py:2083] Graph capturing finished in 43 secs, took 0.55 GiB
INFO 06-19 12:21:31 [core.py:173] init engine (profile, create kv cache, warmup model) took 163.61 seconds
Adding requests:   0% 0/4 [00:00<?, ?it/s]Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
WARNING 06-19 12:21:41 [gemma3_mm.py:124] `do_pan_and_scan=True` has suboptimal results on V1 because of the simplified attention pattern being used.
Adding requests: 100% 4/4 [00:08<00:00,  2.03s/it]
Processed prompts: 100% 4/4 [00:06<00:00,  1.61s/it, est. speed input: 507.36 toks/s, output: 39.79 toks/s]
--------------------------------------------------
Here's a description of the image content:

The image captures a beautiful scene of cherry blossoms in full bloom against a bright blue sky. The blossoms are a delicate shade of pink and are densely packed, creating a visually stunning effect. 

In the background, partially obscured by the branches, is a tall,
--------------------------------------------------
Here's a description of the image content:

The image captures a beautiful scene of cherry blossoms in full bloom against a bright blue sky. The blossoms are a delicate shade of pink and are densely clustered, partially obscuring a tall, modern building (likely a tower) in the background. The branches of the cherry
--------------------------------------------------
Here's a description of the image content:

The image captures a beautiful scene of cherry blossoms in full bloom against a bright blue sky. The blossoms are a delicate shade of pink and are densely packed, creating a visually stunning effect. 

In the background, partially obscured by the branches, is a tall,
--------------------------------------------------
Here's a description of the image content:

The image captures a beautiful scene of cherry blossoms in full bloom against a bright blue sky. The blossoms are a delicate shade of pink and are densely packed, creating a visually stunning effect. 

In the background, partially obscured by the branches, is a tall,
--------------------------------------------------

@Flink-ddd
Contributor Author

Hi @houseroad , following up on our discussion about the vision_model.vision_model double prefix.

I've just pushed a new version of the fix that addresses the issue at a more fundamental level, as you suggested.

Instead of patching the load_weights function in siglip.py, I've implemented a WeightsMapper directly within the gemma3_mm.py model loader. This allows for a cleaner and more direct remapping of the weight names from the quantized artifact to what vLLM expects.

I have verified locally that this new approach also successfully resolves the KeyError and allows the model to be served.

The PR should be ready for review and the final CI run whenever you have a moment. Thanks again for your guidance!

@Flink-ddd
Contributor Author

Hi @DarkLight1337, could you merge this change? I've updated the PR description with the specifics of this code. 😄

@DarkLight1337 DarkLight1337 requested a review from Isotr0py June 21, 2025 14:18
Comment on lines 485 to 489
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "vision_tower.vision_model.": "vision_model.",
        })
Collaborator


Hmmm, but Gemma3ForCausalLM is a text-only model, so why would it have a ViT in its checkpoint?

Contributor Author


Hi @Isotr0py , thank you again for the very insightful question. It helped me clarify the core of the issue, and my apologies for not making the context clearer in the initial PR description.

You are absolutely correct that gemma3 is a text-only LLM.

This fix is specifically designed to support its multi-modal variants, where gemma-3-4b-it is used as the language backbone and is combined with a separate vision encoder (like SigLIP, which is a ViT). In these common VLM (Vision-Language-Model) architectures, the vision encoder is loaded into a component conventionally named vision_tower.

This leads to the exact problem this PR solves: the model checkpoint contains weight names with the prefix vision_tower.vision_model., while vLLM's internal SiglipVisionModel loader expects the prefix to be just vision_model.. This naming mismatch results in the KeyError.

The WeightsMapper I've implemented in gemma3_mm.py directly and cleanly resolves this by providing a rule to remap vision_tower.vision_model. to vision_model., thus allowing the weights to be loaded correctly. This addresses the bug originally reported in vllm-project/llm-compressor#1546.

I hope this provides the necessary context for the change! Please let me know if you have any further questions.
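
For illustration, the remapping rule behaves like a simple prefix substitution. The snippet below is a simplified stand-in for vLLM's WeightsMapper, and the layer path is just an example:

def remap_prefix(name: str, orig_to_new_prefix: dict[str, str]) -> str:
    """Rewrite a checkpoint weight name according to prefix-remapping rules."""
    for orig, new in orig_to_new_prefix.items():
        if name.startswith(orig):
            return new + name[len(orig):]
    return name


rules = {"vision_tower.vision_model.": "vision_model."}
ckpt_name = "vision_tower.vision_model.encoder.layers.0.layer_norm1.weight"  # example name
assert remap_prefix(ckpt_name, rules) == "vision_model.encoder.layers.0.layer_norm1.weight"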

Collaborator


Can you move this mapping from Gemma3ForCausalLM to Gemma3ForConditionalGeneration? Having it in Gemma3ForCausalLM can be confusing.

Contributor Author


Hi @Isotr0py ,

Thank you for the excellent and clear guidance. Your feedback was incredibly helpful.

I have just pushed an updated commit that implements your suggestion precisely.

Specifically, I have reverted the previous changes to gemma3.py and have now placed the fix within gemma3_mm.py. I located the existing WeightsMapper in the Gemma3ForConditionalGeneration class and simply added the necessary rule ("vision_tower.vision_model.": "vision_model.") to its orig_to_new_prefix dictionary.

This is a much cleaner solution that correctly places the vision-related logic within the multi-modal class, fully addressing your concern about keeping Gemma3ForCausalLM clean.

The PR should now be ready for another look. Thanks again for your help!
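
For reference, the end state is conceptually just one extra entry in the existing prefix map. This is a sketch only: the import path and the module-level placement are assumptions for illustration, and the mapper's pre-existing entries are omitted — only the added rule comes from this PR.

from vllm.model_executor.models.utils import WeightsMapper

# Sketch of the rule added to Gemma3ForConditionalGeneration's existing mapper.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        # ...pre-existing prefix rules omitted...
        # mapping for new names in checkpoints saved after transformers v4.52
        "vision_tower.vision_model.": "vision_model.",
    })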

Signed-off-by: Vensenmu <vensenmu@gmail.com>
Signed-off-by: Vensenmu <vensenmu@gmail.com>
@Flink-ddd Flink-ddd force-pushed the fix-gemma3-keyerror branch from 727cb28 to 4ed2749 Compare June 22, 2025 12:25
Signed-off-by: Vensenmu <vensenmu@gmail.com>
@Flink-ddd Flink-ddd force-pushed the fix-gemma3-keyerror branch from 4ed2749 to e74d681 Compare June 22, 2025 12:27
@Isotr0py Isotr0py enabled auto-merge (squash) June 22, 2025 12:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 22, 2025
@Flink-ddd
Contributor Author

Hi @Isotr0py , thank you for triggering the final CI run!

It seems like the basic-models-test failed with a SafetensorError: MetadataIncompleteBuffer when testing the ArthurZ/Ilama-3.2-1B-auto model, which appears to be unrelated to my changes for the Gemma-3 multi-modal loader.

This looks like it might be a transient CI environment or network issue. Would it be possible to re-run the failed job? Please let me know if there's anything I need to do on my end.

Thank you!

@Isotr0py Isotr0py merged commit 493c275 into vllm-project:main Jun 23, 2025
81 checks passed
juncheoll pushed a commit to juncheoll/vllm that referenced this pull request Jun 23, 2025
…lm-compressor (vllm-project#19643)

Signed-off-by: Vensenmu <vensenmu@gmail.com>
Signed-off-by: juncheoll <th6re8e@naver.com>
fhl2000 pushed a commit to fhl2000/vllm that referenced this pull request Jun 25, 2025
…lm-compressor (vllm-project#19643)

Signed-off-by: Vensenmu <vensenmu@gmail.com>
Signed-off-by: fhl <2410591650@qq.com>
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Jun 26, 2025
…lm-compressor (vllm-project#19643)

Signed-off-by: Vensenmu <vensenmu@gmail.com>