32 changes: 31 additions & 1 deletion benchmarks/benchmark_serving.py
@@ -31,6 +31,7 @@
dataset flags to the command below, and make some changes to the tokenizer logic in the
benchmark script (get_tokenizer and sample_requests func) to use your tokenizer correctly.
* Add `--save-result` flag to save the benchmark result to a json file in current folder.
* Add `--threads` flag to set the maximum number of threads used for request dispatching.

(run with real model and engines)
python -m benchmarks.benchmark_serving \
@@ -44,6 +45,7 @@

import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
import json
@@ -229,9 +231,11 @@ async def send_request(
session_cache: str,
priority: int,
max_tokens: int,
threads: int,
) -> RequestFuncOutput:
"""Send the request to wiz server."""
loop = asyncio.get_running_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
request = jetstream_pb2.DecodeRequest(
session_cache=session_cache,
additional_text=prompt,
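
Note: loop.set_default_executor(...) swaps the executor that asyncio uses for every run_in_executor(None, ...) call on that loop, so --threads caps the worker threads available for blocking per-request work. A minimal sketch of that mechanism, for reference only (dispatch and blocking_call are illustrative stand-ins, not the benchmark's actual gRPC path):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def dispatch(blocking_call, threads: int):
        loop = asyncio.get_running_loop()
        # From here on, run_in_executor(None, ...) on this loop uses this pool,
        # capped at `threads` worker threads.
        loop.set_default_executor(ThreadPoolExecutor(max_workers=threads))
        return await loop.run_in_executor(None, blocking_call)

Because send_request runs once per request, each call builds a fresh ThreadPoolExecutor and replaces the loop's default executor; setting it once before the request loop would keep a single shared pool.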
@@ -260,6 +264,7 @@ async def benchmark(
disable_tqdm: bool,
session_cache: str,
priority: int,
threads: int,
):
"""Benchmark the online serving performance."""
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -280,6 +285,7 @@
session_cache=session_cache,
priority=priority,
max_tokens=output_len,
threads=threads,
)
)
)
@@ -329,6 +335,17 @@ async def benchmark(
return result


def mock_requests(total_mock_requests: int):
"""Generates a list of mock requests containing mock data."""
data = []
for _ in range(total_mock_requests):
name = f"Item {random.randint(1, 1000)}"
price = random.randint(10, 100)
quantity = random.randint(1, 10)
data.append((name, price, quantity))
return data


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
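
For reference, each mock entry is a three-item tuple in the same shape as the hard-coded [("AB", 2, 3)] it replaces, so the fields presumably stand in for (prompt text, prompt length, output length) downstream. A small sketch of what mock_requests yields (field meanings inferred, not confirmed by this diff):

    requests = mock_requests(3)
    for prompt, prompt_len, output_len in requests:
        print(prompt, prompt_len, output_len)
    # e.g.
    # Item 731 54 3
    # Item 88 27 9
    # Item 412 96 1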
@@ -341,7 +358,7 @@ def main(args: argparse.Namespace):

tokenizer = get_tokenizer(tokenizer_id)
if tokenizer == "test" or args.dataset == "test":
input_requests = [("AB", 2, 3)]
input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, 3)]
else:
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

@@ -354,6 +371,7 @@
disable_tqdm=args.disable_tqdm,
session_cache=args.session_cache,
priority=args.priority,
threads=args.threads,
)
)

@@ -438,6 +456,18 @@ def main(args: argparse.Namespace):
"the request arrival times."
),
)
parser.add_argument(
"--threads",
type=int,
default=110,
help="The maximum number of threads used for request dispatching.",
)
parser.add_argument(
"--total-mock-requests",
type=int,
default=150,
help="The maximum number of mock requests to send for benchmark testing.",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--disable-tqdm",