In [1]:
import os
from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv())

HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")


In [2]:
"""
This example shows how to use the multi-LoRA functionality
for offline inference.

Requires HuggingFace credentials for access to Llama2.
"""

from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


In [3]:


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be ran after all requests with the
    first adapter have finished.
    """
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora2", 2, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
    ]



In [4]:

def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
    return request_outputs


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    #   batch. Larger numbers will cause higher memory usage, as each LoRA
    #   slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    #   numbers will cause higher memory usage. If you know that all LoRAs will
    #   use the same rank, it is recommended to set this as low as possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    return process_requests(engine, test_prompts)


if __name__ == '__main__':
    result = main()

INFO 06-09 22:11:53 llm_engine.py:161] Initializing an LLM engine (v0.4.3) with config: model='meta-llama/Llama-2-7b-hf', speculative_config=None, tokenizer='meta-llama/Llama-2-7b-hf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Llama-2-7b-hf)




INFO 06-09 22:11:54 weight_utils.py:207] Using model weights format ['*.safetensors']
INFO 06-09 22:11:56 model_runner.py:146] Loading model weights took 12.5562 GB
INFO 06-09 22:11:57 gpu_executor.py:83] # GPU blocks: 3786, # CPU blocks: 512
INFO 06-09 22:11:59 model_runner.py:854] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-09 22:11:59 model_runner.py:858] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-09 22:12:03 model_runner.py:924] Graph capturing finished in 4 secs.


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

RequestOutput(request_id=2, prompt='[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]', prompt_token_ids=[1, 32000, 29871, 14350, 263, 3758, 2346, 304, 1234, 278, 1139, 2729, 373, 278, 1591, 10938, 29889, 13, 13, 3030, 29901, 14602, 10911, 1591, 29918, 978, 29918, 29955, 29946, 313, 983, 29877, 21748, 29892, 4799, 637, 21748, 29897, 13, 13, 1139, 29901, 4408, 278, 306, 5454, 29949, 363, 301, 309, 549, 705, 6121, 4799, 637, 29871, 32001, 259, 32002], prompt_logprobs=[None, {32000: Logprob(logprob=-0.004810901824384928, rank=1, decoded_token='[user]')}, {29871: Logprob(logprob=-0.12253902107477188, rank=1, decoded_token='')}, {14350: Logprob(logprob=-0.0002803409588523209, rank=1, decoded_token='Write')}, {263: Logprob(logprob=-0.0004917366313748062, rank=1, decoded_token='a')}, {3758: Logprob(logprob=-0.000431

In [5]:
result[0]

RequestOutput(request_id=0, prompt='A robot may not injure a human being', prompt_token_ids=[1, 319, 19964, 1122, 451, 10899, 545, 263, 5199, 1641], prompt_logprobs=[None, {319: Logprob(logprob=-4.5557684898376465, rank=8, decoded_token='A'), 917: Logprob(logprob=-2.5245184898376465, rank=1, decoded_token='Tags')}, {19964: Logprob(logprob=-8.996478080749512, rank=1474, decoded_token='robot'), 716: Logprob(logprob=-3.735247850418091, rank=1, decoded_token='new')}, {1122: Logprob(logprob=-4.979100227355957, rank=22, decoded_token='may'), 293: Logprob(logprob=-1.4712878465652466, rank=1, decoded_token='ic')}, {451: Logprob(logprob=-2.097038745880127, rank=2, decoded_token='not'), 367: Logprob(logprob=-1.3626636266708374, rank=1, decoded_token='be')}, {10899: Logprob(logprob=-0.5526919960975647, rank=1, decoded_token='inj')}, {545: Logprob(logprob=-0.0006244616815820336, rank=1, decoded_token='ure')}, {263: Logprob(logprob=-0.027345167472958565, rank=1, decoded_token='a')}, {5199: Logprob(