
[Bug]: Shape error encountered in speculative decoding when enable_lora=True #4872

Open
mitchellstern opened this issue May 17, 2024 · 0 comments
Labels
bug Something isn't working


Your current environment

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.11.6 (main, Oct  3 2023, 01:26:22) [Clang 17.0.1 ] (64-bit runtime)
Python platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             12
On-line CPU(s) list:                0-11
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Gold 5317 CPU @ 3.00GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 1
Core(s) per socket:                 12
Socket(s):                          1
Stepping:                           6
BogoMIPS:                           6000.17
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush acpi mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti intel_ppin ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves umip pku ospke gfni vaes vpclmulqdq rdpid md_clear flush_l1d
Hypervisor vendor:                  Xen
Virtualization type:                full
L1d cache:                          576 KiB (12 instances)
L1i cache:                          384 KiB (12 instances)
L2 cache:                           15 MiB (12 instances)
L3 cache:                           216 MiB (12 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] onnxruntime==1.17.3
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[pip3] vllm_nccl_cu12==2.18.1.0.4.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.2
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	0-11	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I'd like to try out the recently added speculative decoding features. However, when the model is created with `enable_lora=True`, I hit a shape error at the following line in `vllm/spec_decode/batch_expansion.py` (`_contract_batch`), even when the request does not use a LoRA adapter:

```python
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
                                              self._vocab_size)
```
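For reference, this reshape only succeeds when `target_probs` holds exactly `spec_expanded_bs * (k + 1) * vocab_size` values. The standalone sketch below copies the line into a hypothetical helper, `contract_probs`, and runs it on dummy tensors with the sizes from this report (1 sequence, `k = 5`, a 64000-token vocabulary). The input that is 256 columns wider is an assumption chosen to match the error reported further down, not something read from vLLM's internals.

```python
import torch


def contract_probs(target_probs: torch.Tensor, spec_expanded_bs: int,
                   k: int, vocab_size: int) -> torch.Tensor:
    # Hypothetical standalone copy of the reshape quoted above:
    # (spec_expanded_bs, k + 1, vocab_size) rows of probabilities.
    return target_probs.squeeze().reshape(spec_expanded_bs, k + 1, vocab_size)


spec_expanded_bs, k, vocab_size = 1, 5, 64000

# Row width matches vocab_size: the reshape succeeds.
ok = contract_probs(torch.zeros(spec_expanded_bs * (k + 1), vocab_size),
                    spec_expanded_bs, k, vocab_size)
print(ok.shape)  # torch.Size([1, 6, 64000])

# Assumed LoRA case: each row is 256 entries wider than vocab_size.
padded = torch.zeros(spec_expanded_bs * (k + 1), vocab_size + 256)
error_message = ""
try:
    contract_probs(padded, spec_expanded_bs, k, vocab_size)
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

The second call fails with the same `RuntimeError` text as the traceback below, which suggests the problem is the per-row width of the probabilities rather than the batch or `k` dimensions.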

Minimal reproduction with enable_lora=True (encounters a shape error):

```python
from vllm import LLM

llm = LLM(
    model="01-ai/Yi-6B-Chat",
    enable_lora=True,  # the only difference from the working example below
    use_v2_block_manager=True,
    speculative_model="[ngram]",  # n-gram prompt-lookup drafter
    num_speculative_tokens=5,
    ngram_prompt_lookup_min=4,
    ngram_prompt_lookup_max=8,
)

print(llm.generate("Hello, my name is"))
```
INFO 05-17 00:49:46 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='01-ai/Yi-6B-Chat', speculative_config=SpeculativeConfig(draft_model='[ngram]', num_spec_tokens=5), tokenizer='01-ai/Yi-6B-Chat', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=01-ai/Yi-6B-Chat)
INFO 05-17 00:49:47 utils.py:660] Found nccl from library /home/paperspace/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-17 00:49:48 selector.py:27] Using FlashAttention-2 backend.
INFO 05-17 00:49:48 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-17 00:49:51 model_runner.py:175] Loading model weights took 11.2944 GB
INFO 05-17 00:49:51 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-17 00:49:53 model_runner.py:175] Loading model weights took 11.2935 GB
INFO 05-17 00:49:54 gpu_executor.py:114] # GPU blocks: 24329, # CPU blocks: 4096
INFO 05-17 00:49:55 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-17 00:49:55 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-17 00:50:01 model_runner.py:1017] Graph capturing finished in 5 secs.
INFO 05-17 00:50:02 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-17 00:50:02 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-17 00:50:08 model_runner.py:1017] Graph capturing finished in 6 secs.
Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s]
WARNING 05-17 00:50:08 multi_step.py:57] Prompt logprob is not supported by multi step workers. (e.g., speculative decode uses multi step workers).
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/paperspace/vllm_speculative_decoding_bug.py", line 13, in <module>
[rank0]:     print(llm.generate("Hello, my name is"))
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 219, in generate
[rank0]:     return self._run_engine(use_tqdm)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 247, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/engine/llm_engine.py", line 595, in step
[rank0]:     output = self.model_executor.execute_model(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/executor/gpu_executor.py", line 122, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py", line 208, in execute_model
[rank0]:     return self._run_speculative_decoding_step(execute_model_req)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.cache/bazel/_bazel_paperspace/d3aa00cc2e101179676388fec04b7ba3/external/python_x86_64-unknown-linux-gnu/lib/python3.11/contextlib.py", line 81, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/spec_decode/spec_decode_worker.py", line 252, in _run_speculative_decoding_step
[rank0]:     proposal_scores = self.scorer.score_proposals(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.cache/bazel/_bazel_paperspace/d3aa00cc2e101179676388fec04b7ba3/external/python_x86_64-unknown-linux-gnu/lib/python3.11/contextlib.py", line 81, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py", line 87, in score_proposals
[rank0]:     all_tokens, all_probs, spec_logprobs = self._contract_batch(
[rank0]:                                            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/paperspace/.venv/lib/python3.11/site-packages/vllm/spec_decode/batch_expansion.py", line 172, in _contract_batch
[rank0]:     target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: shape '[1, 6, 64000]' is invalid for input of size 385536
Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s]
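The numbers in the error message can be unpacked with a little arithmetic. This is my own reading, not something vLLM reports: the target shape `(spec_expanded_bs, k + 1, vocab_size)` = `(1, 6, 64000)` needs 384000 values, but the tensor holds 385536, which is exactly 6 rows of 64256. The extra 256 columns per row would be consistent with LoRA's extra vocabulary slots (I believe vLLM's `lora_extra_vocab_size` defaults to 256), but that link is an assumption I haven't confirmed in the source.

```python
# Numbers copied from the RuntimeError above.
target_shape = (1, 6, 64000)  # (spec_expanded_bs, k + 1, vocab_size)
actual_numel = 385536

rows = target_shape[0] * target_shape[1]             # 6 scored positions
expected_numel = rows * target_shape[2]              # values the reshape needs
actual_row_width = actual_numel // rows              # width of each actual row
extra_columns = actual_row_width - target_shape[2]   # surplus per row

print(expected_numel, actual_row_width, extra_columns)  # 384000 64256 256
```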

Minimal reproduction without enable_lora=True (runs without error):

```python
from vllm import LLM

llm = LLM(
    model="01-ai/Yi-6B-Chat",
    # enable_lora=True,  # commented out: this configuration runs without error
    use_v2_block_manager=True,
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_min=4,
    ngram_prompt_lookup_max=8,
)

print(llm.generate("Hello, my name is"))
```
INFO 05-17 00:55:39 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='01-ai/Yi-6B-Chat', speculative_config=SpeculativeConfig(draft_model='[ngram]', num_spec_tokens=5), tokenizer='01-ai/Yi-6B-Chat', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=01-ai/Yi-6B-Chat)
INFO 05-17 00:55:40 utils.py:660] Found nccl from library /home/paperspace/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-17 00:55:41 selector.py:27] Using FlashAttention-2 backend.
INFO 05-17 00:55:41 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-17 00:55:43 model_runner.py:175] Loading model weights took 11.2905 GB
INFO 05-17 00:55:43 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-17 00:55:45 model_runner.py:175] Loading model weights took 11.2896 GB
INFO 05-17 00:55:46 gpu_executor.py:114] # GPU blocks: 24415, # CPU blocks: 4096
INFO 05-17 00:55:48 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-17 00:55:48 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-17 00:55:53 model_runner.py:1017] Graph capturing finished in 5 secs.
INFO 05-17 00:55:54 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-17 00:55:54 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-17 00:55:59 model_runner.py:1017] Graph capturing finished in 5 secs.
Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s]
WARNING 05-17 00:55:59 multi_step.py:57] Prompt logprob is not supported by multi step workers. (e.g., speculative decode uses multi step workers).
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
[RequestOutput(request_id=0, prompt='Hello, my name is', prompt_token_ids=[29915, 97, 826, 1815, 620], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' Jake and I’ll be approaching AI to paint the image below. My prompt', token_ids=[26758, 597, 616, 59629, 928, 629, 20098, 13821, 592, 7718, 567, 2728, 2723, 98, 2439, 9187], cumulative_logprob=-66.05616342741996, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1715907359.5151591, last_token_time=1715907359.5151591, first_scheduled_time=1715907359.5172217, first_token_time=1715907359.5404856, time_in_queue=0.002062559127807617, finished_time=1715907359.8753097), lora_request=None)]

mitchellstern added the "bug" (Something isn't working) label on May 17, 2024.