Support vllm engine. #40

Merged: 12 commits merged into SakuraLLM:main from the vllm branch on Jan 17, 2024

Conversation

@Isotr0py (Contributor) commented Jan 14, 2024

Related issue: #39

  • Support the vLLM inference backend

TODO:

  • Support streaming output (AsyncLLMEngine) -> Update: needed a fix; now fixed
  • Support non-streaming output (LLM); a minimal usage sketch follows below
  • Test inference with GPTQ/AWQ quantized models
  • Add requirements -> Update: conflicts with transformers==4.33.2, so it won't be added to requirements.txt

Not all tests are finished yet, so I'm submitting this as a draft first.
Update: done.
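
For reference, a minimal sketch of how the non-streaming path can be driven with vLLM's LLM class. The model path and sampling values are illustrative; this is not the exact code added in this PR.

```python
from vllm import LLM, SamplingParams

# Illustrative settings, not the server's actual configuration.
llm = LLM(model="SakuraLLM/Sakura-13B-LNovel-v0_8-4bit",
          quantization="gptq",
          trust_remote_code=True,
          enforce_eager=True)

params = SamplingParams(temperature=0.1, max_tokens=512)

# generate() blocks until all requests finish, hence "non-streaming".
outputs = llm.generate(["こんにちは"], params)
print(outputs[0].outputs[0].text)
```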

Install

Run:

# cu121
pip3 install https://github.com/vllm-project/vllm/releases/download/v0.2.7/vllm-0.2.7-cp310-cp310-manylinux1_x86_64.whl

or

# cu118
pip3 install https://github.com/vllm-project/vllm/releases/download/v0.2.7/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl

Then run: pip3 install transformers==4.33.2 sentencepiece xformers

@pipixia244 added the enhancement and server labels on Jan 14, 2024
@kurikomoe (Collaborator) commented:

I will open a PR to group the arguments into separate groups after this PR is merged; currently they're kind of messy.

@Isotr0py (Contributor, Author) commented Jan 14, 2024

Test results

  • Sakura-13B-LNovel-v0_8-4bit GPTQ model works with tensor_parallel_size=1.
  • Sakura-7B 4bit AWQ model works with tensor_parallel_size=2.

Benchmark (Sakura-13B-LNovel-v0_8-4bit)

  • vllm (tensor_parallel_size=1, enforce_eager=True, T4 GPU)
INFO 01-14 14:30:33 llm_engine.py:706] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 27.4 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 13.2%, CPU KV cache usage: 0.0%
INFO 01-14 14:30:38 llm_engine.py:706] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 27.0 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 15.8%, CPU KV cache usage: 0.0%
INFO 01-14 14:30:43 llm_engine.py:706] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 26.7 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 18.3%, CPU KV cache usage: 0.0%
Processed prompts: 100%|██████████████████████████| 1/1 [00:19<00:00, 19.72s/it]
2024-01-14 14:30:47 07e90ad7a5a0 utils.model[4917] INFO Output generated in 19.72 seconds (25.96 tokens/s, 512 tokens, context 505 tokens)
  • autogptq (distributed inference, T4 GPU*2)
2024-01-14 14:47:21 07e90ad7a5a0 utils.model[4972] INFO Output generated in 140.95 seconds (3.63 tokens/s, 512 tokens, context 505 tokens)

Known issues

  • The Sakura-13B-LNovel-v0_8-4bit GPTQ model doesn't work with tensor_parallel_size=2,
    because the constraint input_size_per_partition % (quant_config.group_size * tensor_parallel_size) == 0 is not met, where input_size_per_partition=13696 and quant_config.group_size=128 (see the sketch below).

Stream output doesn't work yet; it will be fixed soon.
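
For reference, a minimal sketch of the failing divisibility constraint described above. The helper is purely illustrative, not vLLM's actual implementation.

```python
# Illustrative check: GPTQ column-parallel layers need each partition's input
# size to cover a whole number of quantization groups across all tensor-parallel ranks.
def gptq_partition_ok(input_size_per_partition: int,
                      group_size: int,
                      tensor_parallel_size: int) -> bool:
    return input_size_per_partition % (group_size * tensor_parallel_size) == 0

# Sakura-13B-LNovel-v0_8-4bit: input_size_per_partition=13696, group_size=128
print(gptq_partition_ok(13696, 128, 1))  # True  -> tensor_parallel_size=1 works
print(gptq_partition_ok(13696, 128, 2))  # False -> tensor_parallel_size=2 fails
```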

@Isotr0py Isotr0py marked this pull request as ready for review January 16, 2024 06:12
@pipixia244 (Collaborator) commented Jan 16, 2024

So, is there any solution for the fake stream output? If the streaming mode only produces output after all inference is done, it's pointless.

Maybe we can implement true streaming output by referring to Qwen's implementation or vLLM's OpenAI API server implementation. I'll get on it soon.
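
For reference, a minimal sketch of true streaming via vLLM's AsyncLLMEngine, the same engine its OpenAI API server is built on. The model path and sampling values are illustrative, and this is not the code adopted in this PR.

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid

# Illustrative model path and settings.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="SakuraLLM/Sakura-13B-LNovel-v0_8-4bit",
                    quantization="gptq",
                    trust_remote_code=True))

async def stream(prompt: str) -> None:
    params = SamplingParams(temperature=0.1, max_tokens=512)
    previous = ""
    # generate() yields partial RequestOutputs as tokens are produced,
    # so only the newly generated suffix is printed each step.
    async for output in engine.generate(prompt, params, random_uuid()):
        text = output.outputs[0].text
        print(text[len(previous):], end="", flush=True)
        previous = text

asyncio.run(stream("こんにちは"))
```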

@kurikomoe (Collaborator) commented Jan 16, 2024

vllm 0.2.7 requires pydantic==1.10.13, but you have pydantic 2.5.3 which is incompatible.
vllm 0.2.7 requires transformers>=4.36.0, but you have transformers 4.33.2 which is incompatible.

AFAIK, Sakura-13B-LNovel-v0_8-4bit is based on Baichuan, which needs transformers==4.33.2, or we'd need to change some code (like sp_model) to fix the compatibility issue.

As for the streaming output and nested_async, give me some time to get this PR running on my PC. The plain Sakura-13B-LNovel-v0_8-4bit failed to start with the command below; maybe it's something related to the torch version. I will try the 0.8 AWQ later.

python3 server.py --listen 0.0.0.0:5000 --trust_remote_code --model_name_or_path ./models/Sakura-13B-LNovel-v0_8-4bit  --model_version 0.8 --no-auth --log debug  --vllm

OK, my fault: it was just the 3090 running out of memory when trying to run Sakura-13B-LNovel-v0_8-4bit with --vllm.

@Isotr0py (Contributor, Author) commented:

It's strange that a 3090 would OOM. It seems to be because --use_gptq_model is missing.
My command to run Sakura-13B-LNovel-v0_8-4bit on a 15 GB T4:

python server.py \
    --model_name_or_path SakuraLLM/Sakura-13B-LNovel-v0_8-4bit \
    --vllm \
    --use_gptq_model \
    --model_version 0.8 \
    --trust_remote_code \
    --no-auth \
    --tensor_parallel_size 1 \
    --enforce_eager \
    --gpu_memory_utilization 0.95


@kurikomoe (Collaborator) commented Jan 16, 2024

> It's strange that a 3090 would OOM. It seems to be because --use_gptq_model is missing. My command to run Sakura-13B-LNovel-v0_8-4bit on a 15 GB T4: [...]

OK, --use_gptq_model works.
Then I need to think of a way to validate the model_name/version/quant/quant_methods against the command-line options in the next CLI patch; a rough sketch of one possible approach follows below.
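
For reference, a minimal sketch of what such validation could look like. The flag names mirror the ones used in this thread, but the rules and the validate_args helper are purely illustrative, not the actual CLI patch.

```python
import argparse

# Purely illustrative rules; a real implementation would derive the expected
# quantization method and version from model metadata, not name heuristics.
def validate_args(args: argparse.Namespace) -> None:
    name = args.model_name_or_path.lower()
    if ("gptq" in name or "4bit" in name) and not args.use_gptq_model:
        raise SystemExit(
            f"{args.model_name_or_path} looks like a GPTQ checkpoint; "
            "pass --use_gptq_model, otherwise it may be loaded unquantized and OOM.")
    if "v0_8" in name and args.model_version != "0.8":
        raise SystemExit("--model_version should be 0.8 for a v0_8 checkpoint.")

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", required=True)
parser.add_argument("--model_version", default="0.8")
parser.add_argument("--use_gptq_model", action="store_true")
parser.add_argument("--vllm", action="store_true")
args = parser.parse_args()
validate_args(args)
```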

@Isotr0py (Contributor, Author) commented:
OK, the fake stream output problem should be solved now.

@pipixia244 (Collaborator) commented:
> Then I need to think of a way to validate the model_name/version/quant/quant_methods against the command-line options in the next CLI patch.

This would be very helpful for those who don't know much about how the params work.

@kurikomoe (Collaborator) left a review:

LGTM, tested on the 0.8 4bit GPTQ model.

Still, we need to solve the transformers==4.33.2 issue sooner or later.

@pipixia244 merged commit 0adef2b into SakuraLLM:main on Jan 17, 2024
3 checks passed
@pipixia244 (Collaborator) commented:
I'll update README.md and pyinstaller settings soon.

@Isotr0py deleted the vllm branch on January 17, 2024 15:10