
How does the performance compare with vLLM inference? (vLLM vs Lookahead) #3

Open
buptygz opened this issue Jan 18, 2024 · 5 comments

buptygz commented Jan 18, 2024

Could the benchmark comparison results include a comparison with vLLM, so we can see the acceleration effect?

chenliangjyj (Collaborator) commented Jan 18, 2024

We haven't done a full benchmark yet, but following the latency script in the vLLM repo (https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_latency.py) you can run a simple benchmark:

import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

# from vllm import LLM, SamplingParams
from pia.lookahead.common.lookahead_cache import LookaheadCache
from pia.lookahead.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer


def main(args: argparse.Namespace):
    print(args)

    # Load the lookahead-patched Llama model from the PIA repo
    # (this replaces the vLLM engine used in the original benchmark script).
    llm = LlamaForCausalLM.from_pretrained(args.model
                                         , cache_dir='../'
                                         , torch_dtype=torch.float16
                                         , low_cpu_mem_usage=True
                                         , device_map='auto'
                                         )

    # sampling_params = SamplingParams(
    #     n=args.n,
    #     temperature=0.0 if args.use_beam_search else 1.0,
    #     top_p=1.0,
    #     use_beam_search=args.use_beam_search,
    #     ignore_eos=True,
    #     max_tokens=args.output_len,
    # )
    # print(sampling_params)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    stop_word_ids = set(tokenizer.convert_tokens_to_ids([',', '.', ' ']))
    # lookahead_cache = LookaheadCache(eos=tokenizer.eos_token_id, stop_words=stop_word_ids)
    # model.lookahead_cache = lookahead_cache
    
    prompt = "Hello, I'm am conscious and"
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.cuda()
    attention_mask = inputs.attention_mask.cuda()
    position_ids = None

    debug_lookahead = False
    decoding_length = 63
    branch_length = 12
    max_new_tokens = args.output_len
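    # Lookahead-specific generation kwargs (names follow the PIA repo):
    # decoding_length and branch_length control how many draft tokens are
    # proposed and verified per step; 'hier' selects the hierarchical decoding mode.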
    decoding_kwargs = {"use_lookahead": True,
                       "debug_lookahead": debug_lookahead,
                       "decoding_mode": 'hier',
                       "decoding_length": decoding_length,
                       "branch_length": branch_length,
                       "stop_word_ids": stop_word_ids}

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                decoding_kwargs = {"use_lookahead": True,
                    "debug_lookahead": debug_lookahead,
                    "decoding_mode": 'hier',
                    "decoding_length": decoding_length,
                    "branch_length": branch_length,
                    "stop_word_ids": stop_word_ids}
                llm.generate(input_ids=input_ids,
                             attention_mask=attention_mask,
                             position_ids=position_ids,
                             pad_token_id=tokenizer.eos_token_id,
                             eos_token_id=tokenizer.eos_token_id,
                             use_cache=True,
                             max_new_tokens=max_new_tokens,
                             repetition_penalty=1.0,
                             do_sample=False,
                             decoding_kwargs=decoding_kwargs
                             )
            print(p.key_averages())
        else:
            decoding_kwargs = {"use_lookahead": True,
                                "debug_lookahead": debug_lookahead,
                                "decoding_mode": 'hier',
                                "decoding_length": decoding_length,
                                "branch_length": branch_length,
                                "stop_word_ids": stop_word_ids}
            start_time = time.perf_counter()
            llm.generate(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        pad_token_id=tokenizer.eos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        use_cache=True,
                        max_new_tokens=max_new_tokens,
                        repetition_penalty=1.0,
                        do_sample=False,
                        decoding_kwargs=decoding_kwargs
                        )
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='/mntnlp/common_base_model/llama2-7b')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=(
            'path to save the pytorch profiler output. Can be visualized '
            'with ui.perfetto.dev or Tensorboard.'
        ))
    args = parser.parse_args()
    main(args)
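
For a quick run, assuming the snippet is saved as benchmark_lookahead_latency.py (hypothetical filename) and --model points at a local Llama-2 checkpoint, something like python benchmark_lookahead_latency.py --model /path/to/llama2-7b --output-len 128 --num-iters 3 should work.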

Our method takes around 0.33611 s while vLLM takes around 1.3923719 s on an A100 80G with CUDA 11.7. We focus on single-query latency.
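
For the vLLM side of the comparison, a minimal sketch (assuming vLLM is installed, a placeholder checkpoint path, and sampling settings that mirror the commented-out SamplingParams above) would be:

import time

from vllm import LLM, SamplingParams

llm = LLM(model='/path/to/llama2-7b')  # placeholder: use the same checkpoint as the lookahead run
sampling_params = SamplingParams(n=1,
                                 temperature=0.0,
                                 top_p=1.0,
                                 ignore_eos=True,
                                 max_tokens=128)
prompt = "Hello, I am conscious and"

# warm up once, then time a single greedy generation
llm.generate([prompt], sampling_params)
start = time.perf_counter()
llm.generate([prompt], sampling_params)
print(f'vLLM latency: {time.perf_counter() - start:.4f} s')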


HMJW commented Jan 18, 2024

Can this algorithm be integrated into vLLM? Would the performance be better if it could be?

chenliangjyj (Collaborator) commented

The idea is actually feasible. However, we have not yet tested whether our approach causes the GPU to become compute bound too quickly, which would hurt overall throughput under extremely high load. In cases where QPS is not too high and GPU utilization does not stay consistently high, our method has shown an advantage in single-query latency in internal experiments.

feifeibear commented


I also suggest you guys apply lookahead in a serving framework.

zongquanliu commented

Is there any update on this issue? @chenliangjyj
