From 269710e8f15b9b92e1b9dffeece4575aeb193a74 Mon Sep 17 00:00:00 2001 From: ovoievodin Date: Fri, 26 Sep 2025 15:53:48 -0700 Subject: [PATCH] chore: support specific query --- runners/datafusion-comet/tpcbench.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/runners/datafusion-comet/tpcbench.py b/runners/datafusion-comet/tpcbench.py index ed11938..362a03c 100644 --- a/runners/datafusion-comet/tpcbench.py +++ b/runners/datafusion-comet/tpcbench.py @@ -21,7 +21,7 @@ from pyspark.sql import SparkSession import time -def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str): +def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str, query_num: int = None): # Initialize a SparkSession spark = SparkSession.builder \ @@ -60,7 +60,16 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu for iteration in range(0, iterations): print(f"Starting iteration {iteration} of {iterations}") - for query in range(1, num_queries+1): + # Determine which queries to run + if query_num is not None: + # Validate query number + if query_num < 1 or query_num > num_queries: + raise ValueError(f"Query number {query_num} is out of range. Valid range is 1-{num_queries} for {benchmark}") + queries_to_run = [query_num] + else: + queries_to_run = range(1, num_queries+1) + + for query in queries_to_run: spark.sparkContext.setJobDescription(f"{benchmark} q{query}") # read text file @@ -108,6 +117,7 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu parser.add_argument("--queries", required=True, help="Path to query files") parser.add_argument("--iterations", required=False, default="1", help="How many iterations to run") parser.add_argument("--output", required=True, help="Path to write output") + parser.add_argument("--query", required=False, type=int, help="Specific query number to run (1-based). If not specified, all queries will be run.") args = parser.parse_args() - main(args.benchmark, args.data, args.queries, int(args.iterations), args.output) + main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.query)