Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add Phi-2 LoRA support #4886

Merged
merged 15 commits into from
May 21, 2024
Merged

[Model] Add Phi-2 LoRA support #4886

merged 15 commits into from
May 21, 2024

Conversation

Isotr0py
Copy link
Contributor

@Isotr0py Isotr0py commented May 17, 2024

FILL IN THE PR DESCRIPTION HERE

FIX #4141
FIX #3562

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@Isotr0py
Copy link
Contributor Author

Test

from vllm import LLM
from vllm import SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM("/data/LLM-model/phi-2", enable_lora=True)

sql_lora_path = "/data/PEFT-LoRA/phi-2/phi-2-universal-NER"


prompts = [
    "<|im_start|>human\nText: Mit Patel here from India<|im_end|>\n<|im_start|>gpt\nI've read this text.<|im_end|>\n<|im_start|>human\nwhat is a name of the person in the text?<|im_end|>\n<|im_start|>gpt:\n",
]

sampling_params = SamplingParams(
    temperature=0.8, top_p=0.95, max_tokens=64, stop="<|im_end|>"
)

outputs = llm.generate(
    prompts, sampling_params, lora_request=LoRARequest("phi_adapter", 1, sql_lora_path)
)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Output

INFO 05-17 21:22:23 llm_engine.py:103] Initializing an LLM engine (v0.4.2) with config: model='/data/LLM-model/phi-2', speculative_config=None, tokenizer='/data/LLM-model/phi-2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=/data/LLM-model/phi-2)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
WARNING 05-17 21:22:23 cpu_executor.py:112] float16 is not supported on CPU, casting to bfloat16.
WARNING 05-17 21:22:23 cpu_executor.py:115] CUDA graph is not supported on CPU, fallback to the eager mode.
WARNING 05-17 21:22:23 cpu_executor.py:142] Environment variable VLLM_CPU_KVCACHE_SPACE (GB) for CPU backend is not set, using 4 by default.
WARNING 05-17 21:22:23 punica.py:14] punica LoRA kernels require a GPU to run. But you are using the CPU version vLLM
INFO 05-17 21:22:24 selector.py:52] Using Torch SDPA backend.
INFO 05-17 21:22:35 cpu_executor.py:71] # CPU blocks: 819
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [09:09<00:00, 549.94s/it, Generation Speed: 0.02 toks/s]
Prompt: "<|im_start|>human\nText: Mit Patel here from India<|im_end|>\n<|im_start|>gpt\nI've read this text.<|im_end|>\n<|im_start|>human\nwhat is a name of the person in the text?<|im_end|>\n<|im_start|>gpt:\n", Generated text: 'Mit Patel\n'

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Consider adding a test similar to https://github.com/vllm-project/vllm/blob/main/tests/lora/test_llama.py? Also cc @Yard1

@Isotr0py
Copy link
Contributor Author

Isotr0py commented May 17, 2024

OK, but I haven't tested this on punica kernel yet, and only tested it on the cpu kernel proposed in another PR: #4830. I will test it in GPU environment later.

It seems that Phi-2 has a different vocab_size compared to existing punica dimension in bgmv_config.h, I wonder how we can determine the punica dimension we should add in `bgmv_config.h since it's a little different from original model vocab_size.

@rkooo567
Copy link
Collaborator

cc @Yard1 for the question above (who's more familiar with punica kernel)

@Isotr0py
Copy link
Contributor Author

Well, this works on RTX3080 Ti with punica kernel so we don't need to modify the punica code anymore :)

INFO 05-18 01:19:13 llm_engine.py:103] Initializing an LLM engine (v0.4.2) with config: model='/data/LLM-model/phi-2', speculative_config=None, tokenizer='/data/LLM-model/phi-2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=/data/LLM-model/phi-2)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
INFO 05-18 01:19:14 selector.py:37] Using FlashAttention-2 backend.
INFO 05-18 01:19:16 model_runner.py:145] Loading model weights took 5.1933 GB
INFO 05-18 01:19:18 gpu_executor.py:83] # GPU blocks: 948, # CPU blocks: 819
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.72it/s, Generation Speed: 27.19 toks/s]
Prompt: "<|im_start|>human\nText: Mit Patel here from India<|im_end|>\n<|im_start|>gpt\nI've read this text.<|im_end|>\n<|im_start|>human\nwhat is a name of the person in the text?<|im_end|>\n<|im_start|>gpt:\n", Generated text: 'Mit Patel\n'

I will add the remaining test later.

@rkooo567
Copy link
Collaborator

sounds great! It is awesome we just need to change supported modules specification

@Yard1 Yard1 self-requested a review May 17, 2024 20:08
@rkooo567 rkooo567 self-assigned this May 18, 2024
@Isotr0py Isotr0py marked this pull request as ready for review May 19, 2024 06:27
@Isotr0py
Copy link
Contributor Author

Emmm, it's strange that the CI worker OOM during running test_phi.py in lora-test-3:

=========================== short test summary info ============================
FAILED lora/test_phi.py::test_phi2_lora - RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
FAILED lora/test_punica.py::test_lora_a_extra_shapes[11806239938951-1-256-float16] - RuntimeError: CUDA generator expects graph capture to be underway, but the current stream is not capturing.
FAILED lora/test_punica.py::test_lora_a_extra_shapes[11806239938951-1-1024-float16] - RuntimeError: CUDA generator expects graph capture to be underway, but the current stream is not capturing.

Is there any way to solve this?

@Yard1
Copy link
Collaborator

Yard1 commented May 19, 2024

@Isotr0py try reducing max_loras in the test to 2

@Isotr0py
Copy link
Contributor Author

It seems that enforce_eager=True can solve the OOM problem, but should we use this in test CI?

@Yard1
Copy link
Collaborator

Yard1 commented May 20, 2024

@Isotr0py we can keep it but can you add a comment whenever you use enforce_eager in test to clarify why we set it there?

@rkooo567
Copy link
Collaborator

yeah technically enforce_eager=False case should be already tested

@Isotr0py
Copy link
Contributor Author

I have added a comment to mark that enforce_eager=True is used for test_phi2 to reduce memory usage for CI test.

@rkooo567 rkooo567 merged commit f12c3b5 into vllm-project:main May 21, 2024
61 checks passed
@Isotr0py Isotr0py deleted the phi2_lora branch May 21, 2024 07:01
@rkooo567
Copy link
Collaborator




FAILED lora/test_phi.py::test_phi2_lora - ValueError: Head size 80 is not supported by FlashAttention. Supported head sizes are: [32, 64, 96, 128, 160, 192, 224, 256].
 

<br class="Apple-interchange-newline">

looks like the test is failing with this error.

@Isotr0py
Copy link
Contributor Author

@rkooo567 I have created #4944 to fix this, and LoRA-test-3 passed in the fix PR.

However, the lora-long-context-distributed test failed in that PR, can you have a look at that?

@rkooo567
Copy link
Collaborator

The failure looks kind of unrelated... let me retry the test

tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature]: Phi2 LoRA support [New Model]: Phi-2 support for LoRA
3 participants