From b99016cb7eed87641eb96d005f71bca6e21cab03 Mon Sep 17 00:00:00 2001 From: Hong-Yi Lin Date: Thu, 28 Aug 2025 04:04:29 +0000 Subject: [PATCH] Fix error in src/benchmark_utils.py --- src/benchmark_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/benchmark_utils.py b/src/benchmark_utils.py index 742c75f..ee1c478 100644 --- a/src/benchmark_utils.py +++ b/src/benchmark_utils.py @@ -19,7 +19,7 @@ import shutil -def simple_timeit(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float: +def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float: """Simple utility to time a function for multiple runs.""" assert task is not None @@ -97,7 +97,7 @@ def is_local_directory_path(dir: str) -> bool: return dir.startswith("/") or dir.startswith("./") or dir.startswith("../") -def timeit_from_trace(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float: +def timeit_from_trace(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float: """ Time a function with jax.profiler and get the run time from the trace. """ @@ -105,7 +105,12 @@ def timeit_from_trace(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) jax.block_until_ready(f(*args)) # warm it up! - trace_name = f"{task}_dim_{matrix_dim}" + if matrix_dim is not None: + trace_name = f"{task}_dim_{matrix_dim}" + else: + trace_name = f"t_{task}_" + "".join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ) trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir