
[Bug]: Enable LORA on Version 0.9.1 and RTX 5090 causes an issue #19693

Open
@bely66

Description


Your current environment

Output of the following command:
vllm serve models/Llama-3.1-8B --max-model-len 9000 --quantization bitsandbytes --load-format bitsandbytes --enable-lora

INFO 06-16 07:20:22 [__init__.py:244] Automatically detected platform cuda.
INFO 06-16 07:20:26 [api_server.py:1287] vLLM API server version 0.9.1
INFO 06-16 07:20:26 [cli_args.py:309] non-default args: {'model': 'models/Llama-3.1-8B', 'max_model_len': 9000, 'quantization': 'bitsandbytes', 'load_format': 'bitsandbytes', 'enable_lora': True}
INFO 06-16 07:20:32 [config.py:823] This model supports multiple tasks: {'classify', 'score', 'embed', 'generate', 'reward'}. Defaulting to 'generate'.
WARNING 06-16 07:20:32 [config.py:931] bitsandbytes quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 06-16 07:20:32 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=2048.
WARNING 06-16 07:20:34 [env_override.py:17] NCCL_CUMEM_ENABLE is set to 0, skipping override. This may increase memory overhead with cudagraph+allreduce: https://github.com/NVIDIA/nccl/issues/1234
INFO 06-16 07:20:36 [__init__.py:244] Automatically detected platform cuda.
INFO 06-16 07:20:38 [core.py:455] Waiting for init message from front-end.
INFO 06-16 07:20:38 [core.py:70] Initializing a V1 LLM engine (v0.9.1) with config: model='models/Llama-3.1-8B', speculative_config=None, tokenizer='models/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=9000, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=models/Llama-3.1-8B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 06-16 07:20:39 [utils.py:2737] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x70623c5342c0>
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 06-16 07:20:39 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 06-16 07:20:39 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
INFO 06-16 07:20:39 [gpu_model_runner.py:1595] Starting to load model models/Llama-3.1-8B...
INFO 06-16 07:20:39 [gpu_model_runner.py:1600] Loading model from scratch...
INFO 06-16 07:20:39 [cuda.py:252] Using Flash Attention backend on V1 engine.
INFO 06-16 07:20:40 [bitsandbytes_loader.py:454] Loading weights with BitsAndBytes quantization. May take a while ...
INFO 06-16 07:20:41 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 06-16 07:20:42 [gpu_model_runner.py:1624] Model loading took 5.4363 GiB and 2.171765 seconds
INFO 06-16 07:20:49 [backends.py:460] vLLM's torch.compile cache is disabled.
INFO 06-16 07:20:49 [backends.py:472] Dynamo bytecode transform time: 6.32 s
INFO 06-16 07:20:49 [backends.py:161] Cache the graph of shape None for later use
INFO 06-16 07:20:53 [backends.py:173] Compiling a graph for general shape takes 3.57 s
INFO 06-16 07:20:56 [monitor.py:34] torch.compile takes 9.89 s in total
INFO 06-16 07:20:56 [gpu_worker.py:227] Available KV cache memory: 22.24 GiB
INFO 06-16 07:20:57 [kv_cache_utils.py:715] GPU KV cache size: 182,224 tokens
INFO 06-16 07:20:57 [kv_cache_utils.py:719] Maximum concurrency for 9,000 tokens per request: 20.23x
ERROR 06-16 07:20:57 [core.py:515] EngineCore failed to start.
ERROR 06-16 07:20:57 [core.py:515] Traceback (most recent call last):
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 506, in run_engine_core
ERROR 06-16 07:20:57 [core.py:515]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 390, in __init__
ERROR 06-16 07:20:57 [core.py:515]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 83, in __init__
ERROR 06-16 07:20:57 [core.py:515]     self._initialize_kv_caches(vllm_config)
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 168, in _initialize_kv_caches
ERROR 06-16 07:20:57 [core.py:515]     self.model_executor.initialize_from_config(kv_cache_configs)
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 66, in initialize_from_config
ERROR 06-16 07:20:57 [core.py:515]     self.collective_rpc("compile_or_warm_up_model")
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-16 07:20:57 [core.py:515]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-16 07:20:57 [core.py:515]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2671, in run_method
ERROR 06-16 07:20:57 [core.py:515]     return func(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 261, in compile_or_warm_up_model
ERROR 06-16 07:20:57 [core.py:515]     self.model_runner.capture_model()
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 2040, in capture_model
ERROR 06-16 07:20:57 [core.py:515]     self._dummy_run(num_tokens, skip_attn=skip_attn)
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
ERROR 06-16 07:20:57 [core.py:515]     return func(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1847, in _dummy_run
ERROR 06-16 07:20:57 [core.py:515]     outputs = model(
ERROR 06-16 07:20:57 [core.py:515]               ^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
ERROR 06-16 07:20:57 [core.py:515]     return self._call_impl(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
ERROR 06-16 07:20:57 [core.py:515]     return forward_call(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 581, in forward
ERROR 06-16 07:20:57 [core.py:515]     model_output = self.model(input_ids, positions, intermediate_tensors,
ERROR 06-16 07:20:57 [core.py:515]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 246, in __call__
ERROR 06-16 07:20:57 [core.py:515]     model_output = self.forward(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 368, in forward
ERROR 06-16 07:20:57 [core.py:515]     def forward(
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 374, in __call__
ERROR 06-16 07:20:57 [core.py:515]     return super().__call__(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
ERROR 06-16 07:20:57 [core.py:515]     return self._call_impl(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
ERROR 06-16 07:20:57 [core.py:515]     return forward_call(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 899, in _fn
ERROR 06-16 07:20:57 [core.py:515]     return fn(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 848, in call_wrapped
ERROR 06-16 07:20:57 [core.py:515]     return self._wrapped_call(self, *args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 424, in __call__
ERROR 06-16 07:20:57 [core.py:515]     raise e
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 411, in __call__
ERROR 06-16 07:20:57 [core.py:515]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
ERROR 06-16 07:20:57 [core.py:515]     return self._call_impl(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
ERROR 06-16 07:20:57 [core.py:515]     return forward_call(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "<eval_with_key>.66", line 787, in forward
ERROR 06-16 07:20:57 [core.py:515]     submod_0 = self.submod_0(l_input_ids_, s72, l_self_modules_embed_tokens_punica_wrapper_embeddings_indices, l_self_modules_embed_tokens_lora_a_stacked_2d, l_self_modules_embed_tokens_modules_base_layer_parameters_weight_, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_token_lora_mapping, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_token_indices_sorted_by_lora_ids, l_self_modules_embed_tokens_lora_b_stacked, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_num_tokens_per_lora, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_lora_token_start_loc, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_active_lora_ids, l_self_modules_embed_tokens_punica_wrapper_token_mapping_meta_no_lora_flag_cpu, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_modules_base_layer_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_modules_base_layer_parameters_weight_bnb_shard_offsets, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_0_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_1_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_2_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_0_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_1_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_2_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_);  l_input_ids_ = l_self_modules_embed_tokens_punica_wrapper_embeddings_indices = l_self_modules_embed_tokens_lora_a_stacked_2d = l_self_modules_embed_tokens_modules_base_layer_parameters_weight_ = l_self_modules_embed_tokens_lora_b_stacked = 
l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_modules_base_layer_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_modules_base_layer_parameters_weight_bnb_shard_offsets = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_0_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_1_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_a_stacked_2_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_0_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_1_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_lora_b_stacked_2_ = None
ERROR 06-16 07:20:57 [core.py:515]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/cuda_piecewise_backend.py", line 152, in __call__
ERROR 06-16 07:20:57 [core.py:515]     return entry.runnable(*args)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 899, in _fn
ERROR 06-16 07:20:57 [core.py:515]     return fn(*args, **kwargs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1240, in forward
ERROR 06-16 07:20:57 [core.py:515]     return compiled_fn(full_args)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 357, in runtime_wrapper
ERROR 06-16 07:20:57 [core.py:515]     all_outs = call_func_at_runtime_with_args(
ERROR 06-16 07:20:57 [core.py:515]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
ERROR 06-16 07:20:57 [core.py:515]     out = normalize_as_list(f(args))
ERROR 06-16 07:20:57 [core.py:515]                             ^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 529, in wrapper
ERROR 06-16 07:20:57 [core.py:515]     return compiled_fn(runtime_args)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 583, in __call__
ERROR 06-16 07:20:57 [core.py:515]     return self.current_callable(inputs)
ERROR 06-16 07:20:57 [core.py:515]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 2708, in run
ERROR 06-16 07:20:57 [core.py:515]     out = model(new_inputs)
ERROR 06-16 07:20:57 [core.py:515]           ^^^^^^^^^^^^^^^^^
ERROR 06-16 07:20:57 [core.py:515]   File "/tmp/torchinductor_root/l4/cl4vp5tv7aoq7ukwm7kybawurntf6akq2qlc6h2k7cn4672bpmxl.py", line 442, in call
ERROR 06-16 07:20:57 [core.py:515]     assert_size_stride(arg0_1, (2048, ), (1, ))
ERROR 06-16 07:20:57 [core.py:515] AssertionError: expected size 512==2048, stride 1==1 at dim=0
ERROR 06-16 07:20:57 [core.py:515] This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
ERROR 06-16 07:20:57 [core.py:515] Use torch.library.opcheck to test your custom op.
ERROR 06-16 07:20:57 [core.py:515] See https://pytorch.org/docs/stable/library.html#torch.library.opcheck
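
Reading the assertion: the generated Inductor code checks that arg0_1 has shape (2048,), and the message reports "512==2048", so the tensor it actually received had 512 elements at dim 0. In this log, 2048 matches max_num_batched_tokens=2048 (the chunked-prefill line) and 512 matches "max_capture_size":512 (the largest CUDA graph capture size). That correspondence is an observation from this log, not a confirmed root cause. The minimal sketch below is a hypothetical helper, not part of vLLM or PyTorch. It pulls those numbers out of a saved log so runs with different settings can be compared, and its regular expressions assume the exact log formats shown above.

```python
import re
import sys

# Regular expressions for the log formats quoted in this report.
ASSERT_RE = re.compile(r"expected size (\d+)==(\d+), stride (\d+)==(\d+) at dim=(\d+)")
BATCHED_RE = re.compile(r"max_num_batched_tokens=(\d+)")
CAPTURE_RE = re.compile(r'"max_capture_size":(\d+)')


def summarize(log_text):
    """Return the sizes involved in an assert_size_stride failure, or None.

    In each "A==B" pair of the message, A is what the tensor actually had
    and B is what the compiled graph expected.
    """
    m = ASSERT_RE.search(log_text)
    if m is None:
        return None
    actual, expected, _actual_stride, _expected_stride, dim = map(int, m.groups())
    batched = BATCHED_RE.search(log_text)
    capture = CAPTURE_RE.search(log_text)
    return {
        "actual_size": actual,
        "expected_size": expected,
        "dim": dim,
        "max_num_batched_tokens": int(batched.group(1)) if batched else None,
        "max_capture_size": int(capture.group(1)) if capture else None,
    }


if __name__ == "__main__":
    # Usage: python summarize_assert.py serve.log  (file names are placeholders)
    if len(sys.argv) > 1:
        with open(sys.argv[1]) as f:
            print(summarize(f.read()))
```

Run against the log above, it should report actual_size 512, expected_size 2048, dim 0, max_num_batched_tokens 2048 and max_capture_size 512. Rerunning the server with a different --max-num-batched-tokens value and comparing expected_size would show whether the two are linked.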

🐛 Describe the bug

When I run the command

vllm serve models/Llama-3.1-8B --max-model-len 9000 --quantization bitsandbytes --load-format bitsandbytes --enable-lora

without the --enable-lora flag, it works with every configuration I have tried.

As soon as --enable-lora is added, startup fails with the error above.

Note: this is not related to whether I provide LoRA adapters. Passing an adapter with
--lora-modules adapter_name=adapter_path
gives the same error.
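
The flag combinations tried in this report can be run one after another against the same base command with a small script. This is a hypothetical helper, not part of vLLM. It assumes vllm is on PATH, treats uvicorn's "Application startup complete" line as the success signal and the two error strings quoted in this report as failure signals, and keeps the report's adapter_name=adapter_path placeholder, which must be replaced with a real name=path pair. It also includes the --enforce-eager variant from the update to this report.

```python
import shlex
import subprocess
import sys

# Command from the report, without --enable-lora.
BASE_CMD = [
    "vllm", "serve", "models/Llama-3.1-8B",
    "--max-model-len", "9000",
    "--quantization", "bitsandbytes",
    "--load-format", "bitsandbytes",
]

# Flag combinations described in the report. "adapter_name=adapter_path" is
# the report's placeholder; replace it with a real name=path pair.
VARIANTS = {
    "baseline": [],
    "enable_lora": ["--enable-lora"],
    "enable_lora_with_adapter": [
        "--enable-lora", "--lora-modules", "adapter_name=adapter_path",
    ],
    "enable_lora_enforce_eager": ["--enable-lora", "--enforce-eager"],
}

# Assumed log markers: uvicorn's startup message signals success; the two
# error strings quoted in this report signal failure.
OK_MARKERS = ("Application startup complete",)
FAIL_MARKERS = ("EngineCore failed to start", "no kernel image is available")


def build_cmd(extra_flags):
    """Return the argv list for one variant."""
    return BASE_CMD + list(extra_flags)


def classify(line):
    """Map one log line to 'ok', 'fail', or None (no verdict yet)."""
    if any(marker in line for marker in FAIL_MARKERS):
        return "fail"
    if any(marker in line for marker in OK_MARKERS):
        return "ok"
    return None


def run_variant(extra_flags):
    """Start the server, stop it at the first verdict, and return that verdict."""
    proc = subprocess.Popen(
        build_cmd(extra_flags),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge both streams so no marker is missed
        text=True,
    )
    verdict = "exited"  # the process ended without printing a known marker
    try:
        for line in proc.stdout:
            status = classify(line)
            if status is not None:
                verdict = status
                break
    finally:
        proc.terminate()
        proc.stdout.close()
        try:
            proc.wait(timeout=60)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    return verdict


if __name__ == "__main__":
    # Dry run by default; pass --run to actually start each server in turn.
    execute = "--run" in sys.argv[1:]
    for name, flags in VARIANTS.items():
        if execute:
            print(f"{name}: {run_variant(flags)}")
        else:
            print(f"{name}: {shlex.join(build_cmd(flags))}")
```

By default the script only prints the four commands; running it with --run starts each server in turn and prints ok, fail, or exited for each variant, which gives a compact summary to attach to the issue.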

Update: after adding --enforce-eager, the error changes and the CUDA kernel image cannot be found at all:

Error: torch.AcceleratorError: CUDA error: no kernel image is available for execution on the device


Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests