diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index cb1b4965..45459f4e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -342,7 +342,7 @@ def gen_mmlu_qa(data: Any, mmlu_method: str = "") -> str: f"(D) {row['D']}\n" ) - output += "\nCorrect answer: " + output += "\nCorrect answer:" if mmlu_method == "HELM": output += f"({row['answer']})\n\n" @@ -938,7 +938,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--num-prompts", type=int, - default=1000, + default=-1, help=( "Number of prompts to process. (number of sample requests we randomly" " collect from dataset)" @@ -1133,11 +1133,16 @@ def main(args: argparse.Namespace): # A given args.max_output_length value is the max generation step, # when the args.max_output_length is default to None, the sample's golden # output length will be used to decide the generation step. + if args.num_prompts == -1: + num_requests = len(dataset) + else: + num_requests = args.num_prompts + input_requests = sample_requests( dataset=dataset, tokenizer=tokenizer, use_chat_template=use_chat_template, - num_requests=args.num_prompts, + num_requests=num_requests, dataset_type=args.dataset, max_output_length=args.max_output_length, min_input_length=args.min_input_length,