Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Load FP8 kv-cache scaling factors from checkpoints #4893

Merged
merged 6 commits into from
May 22, 2024

Conversation

comaniac
Copy link
Contributor

@comaniac comaniac commented May 17, 2024

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
Specifically,

  1. We now support --kv-cache-dtype {auto, fp8, fp8_e4m3, fp8_e5m2}. auto=fp16 or bf16 and fp8=fp8_e4m3.
  2. If the checkpoint is in FP16, then kv-cache scaling factors can only be loaded via --quantization-param-path; otherwise kv-scale is always 1 regardless fp8_e4m3 or fp8_e5m2.
  3. If the checkpoint is in FP8 (e4m3) AND --kv-cache-dtype {fp8, fp8_e4m3}, kv_scale will be loaded from the checkpoint (if the field presents).

Here is a simple benchmark on a single NVIDIA L4 GPU:

  • FP16 model: meta-llama/Meta-Llama-3-8B-Instruct
  • FP8 model: nm-testing/Meta-Llama-3-8B-Instruct-FP8 (no kv-scale so use 1.0 for now).
  • QPS: 1
  • Total requests: 30
  • Average prompt length: 512
  • Average decoding length: 256 (ignore_eos is not enabled so the actual length may be differ).
  • Temperature: 0
Model Dtype kv-cache Dtype GPU blocks TTFT (ms) ITL (ms) e2e (s)
FP16 FP16 1470 219.0 90.5 15.7
FP16 FP8_e5m2 2940 219.3 88.2 14.9
FP8_e4m3 FP16 4784 148.2 51.3 9.3
FP8_e4m3 FP8_e4m3 9568 147.4 50.1 8.1

@robertgshaw2-neuralmagic @tlrmchlsmth can you help update nm-testing/Meta-Llama-3-8B-Instruct-FP8 to include kv-cache scaling factors so that we could test it? Thanks!

TODO

  • Unit test.

Also cc @pcmoritz @Yard1

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@pcmoritz pcmoritz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! As a next step, we should create an FP8 checkpoint with appropriate scales and make sure the accuracy looks good :)

@tlrmchlsmth
Copy link
Contributor

@robertgshaw2-neuralmagic @tlrmchlsmth can you help update nm-testing/Meta-Llama-3-8B-Instruct-FP8 to include kv-cache scaling factors so that we could test it? Thanks!

I'm pinging some NeuralMagic folks to see if we can get those models updated

@mgoin
Copy link
Collaborator

mgoin commented May 20, 2024

Hey y'all, I made this model as a test https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV. I haven't tested the accuracy yet but it should be sufficient for a smoke test @comaniac @tlrmchlsmth

@comaniac
Copy link
Contributor Author

Hey y'all, I made this model as a test https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV. I haven't tested the accuracy yet but it should be sufficient for a smoke test @comaniac @tlrmchlsmth

Thanks! I'll use this model for testing in this PR and post back soon.

@comaniac
Copy link
Contributor Author

Output comparison (using the example prompts in the tests):

  • meta-llama/Meta-Llama-3-8B-Instruct
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
  • nm-testing/Meta-Llama-3-8B-Instruct-FP8 with FP16 kv-cache
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
  • nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV with FP8 kv-cache and kv-scale=1.0
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in terms of processing information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to life',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o'
  • nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV with FP8 kv-cache and in-checkpoint kv-scale
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'

Although we cannot judge the model quality only by these simple prompts, it should verify that the kv-scale from the checkpoint is loaded correctly.

@comaniac
Copy link
Contributor Author

Side note: Since #4907, the flash-attn is used for both prefill and decoding. However, flash-attn doesn't support FP8 input, so now when FP8 kv-cache is enabled, vLLM will enforce to use xFormers backend (which uses paged attention kernel in decoding).

@comaniac
Copy link
Contributor Author

All done. Final note: Per offline discussion with @mgoin, we should accept the checkpoint with HuggingFace model compatible format. In other words, the kv_scale should have the weight name model.layers.0.self_attn.kv_scale. On the other hand, the implementation in this PR puts kv_scale in model.layers.0.self_attn.attn.kv_scale to hide the implementation details from model implementer. We now let the model specific weight loader to deal with this re-mapping, and thus now only Llama and Mixtral could correctly load the kv-scale.

Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a solid test, I think the tradeoff with model support for checkpoint loading is fair. This falls in line with how we deal with other module replacements.

@pcmoritz pcmoritz merged commit a3a73ab into vllm-project:main May 22, 2024
61 checks passed
@comaniac comaniac deleted the fp8-kv-2 branch May 22, 2024 23:22
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 31, 2024
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
triple-Mu pushed a commit to CC-LLM/vllm that referenced this pull request Jun 5, 2024
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 8, 2024
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants