211 changes: 118 additions & 93 deletions scripts/generate_baseline_time_modal.py
@@ -15,6 +15,8 @@
 import multiprocessing as mp
 import time
 import einops
+import pydra
+from pydra import Config, REQUIRED

 """
 Generate baseline time for KernelBench
@@ -48,6 +50,28 @@

 TIMING_DIR = os.path.join(REPO_TOP_PATH, "results", "timing")

+
+class BaselineConfig(Config):
+    def __init__(self):
+        # Problem level to generate baseline for
+        self.level = REQUIRED
+
+        # GPU type for Modal (L40S, H100, A100, A100-80GB, L4, T4, A10G)
+        self.gpu = REQUIRED
+
+        # Hardware name for saving results
+        self.hardware_name = REQUIRED
+
+        # Batch size for parallel processing
+        self.batch_size = 10
+
+        # Timeout for each batch in seconds
+        self.timeout = 1800
+
+        # Number of trials for timing
+        self.num_trials = 100
+
+
 # Modal Infra
 import modal
 app = modal.App("generate_baseline_modal")
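The new BaselineConfig replaces constants that were previously hardcoded in the __main__ block. For readers unfamiliar with pydra: a Config subclass declares its fields in __init__, REQUIRED fields must be supplied at invocation time, and the rest are overridable defaults. A minimal, self-contained sketch of that mechanism; the key=value CLI syntax is pydra's usual convention, assumed here rather than shown in this diff:

import pydra
from pydra import Config, REQUIRED

class DemoConfig(Config):
    def __init__(self):
        self.level = REQUIRED   # no default: must be set on the command line
        self.batch_size = 10    # default, can be overridden

@pydra.main(base=DemoConfig)
def main(config: DemoConfig):
    # pydra has already parsed sys.argv and populated the config here
    print(config.level, config.batch_size)

if __name__ == "__main__":
    main()

# Hypothetical invocation (assumed key=value convention):
#   python demo.py level=1 batch_size=20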
@@ -127,7 +151,7 @@ def fetch_ref_arch_from_dataset(dataset: list[str],
     ref_arch_name = ref_arch_path.split("/")[-1]
     return (ref_arch_path, ref_arch_name, ref_arch_src)

-@app.cls(image=image, container_idle_timeout=5)
+@app.cls(image=image, scaledown_window=5)
 class EvalFunc:

     @modal.method()
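This one-line change appears to track Modal's rename of the container_idle_timeout argument to scaledown_window; both control how long an idle container stays warm before being reclaimed. Combined with with_options, one class definition can be dispatched to any GPU type at call time, which is what the wrapper below relies on. A minimal sketch under those assumptions (app name and ping method are illustrative, not from the PR):

import modal

app = modal.App("scaledown-demo")
image = modal.Image.debian_slim()

@app.cls(image=image, scaledown_window=5)  # idle containers reclaimed after ~5s
class EvalFunc:
    @modal.method()
    def ping(self) -> str:
        return "ok"

if __name__ == "__main__":
    with app.run():
        # with_options picks the GPU type per call, so the same class
        # can serve L40S, H100, etc. without being redefined
        print(EvalFunc.with_options(gpu="L40S")().ping.remote())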
@@ -188,121 +212,122 @@ def measure_program_time(
         except Exception as e:
             print(f"[Eval] Error in Measuring Performance: {e}")

-def measure_program_time_wrapper(*args, **kwargs):
+def measure_program_time_wrapper(gpu_type, *args, **kwargs):
     with app.run():
-        return EvalFunc.with_options(gpu=gpu)().measure_program_time.remote(*args, **kwargs)
+        return EvalFunc.with_options(gpu=gpu_type)().measure_program_time.remote(*args, **kwargs)

-def record_baseline_times(use_torch_compile: bool = False,
-                          torch_compile_backend: str="inductor",
+def record_baseline_times(config: BaselineConfig,
+                          use_torch_compile: bool = False,
+                          torch_compile_backend: str="inductor",
                           torch_compile_options: str="default",
                           file_name: str="baseline_time.json"):
     """
     Generate baseline time for KernelBench,
     configure profiler options for PyTorch,
     save to specified file
     """
     json_results = []

-    for level in [1, 2, 3]:
-        PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
-        dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
-        num_problems = len(dataset)
-        total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
-
-        with tqdm(total=len(total_work), desc="Processing batches") as pbar:
-            while len(total_work) > 0:
-                curr_work_batch = total_work[:batch_size]
-                total_work = total_work[batch_size:]  # pop the first batch_size elements
-
-                with mp.Pool() as pool:
-
-                    work_args = [
-                        (
-                            ref_arch_name,
-                            ref_arch_src,
-                            100,
-                            use_torch_compile,
-                            torch_compile_backend,
-                            torch_compile_options,
-                            torch.device(f"cuda:0"),
-                            False  # do not print
-                        )
-                        for i, (p_id, ref_arch_path, ref_arch_name, ref_arch_src) in enumerate(curr_work_batch)
-                    ]
-
-                    start_time = time.time()
-
-                    async_results = []
-                    for work_arg in work_args:
-                        async_results.append(
-                            pool.apply_async(measure_program_time_wrapper, work_arg)
-                        )
-
-                    batch_timeout = timeout
-                    for i, async_result in enumerate(async_results):
-                        problem_id, _, ref_arch_name, _ = curr_work_batch[i]
-
-                        try:
-                            elapsed_time = time.time() - start_time
-                            remaining_time = max(0, batch_timeout - elapsed_time)
-                            result = async_result.get(timeout=remaining_time)
-                            json_results.append((f"level{level}", ref_arch_name, result))
-
-                        except mp.TimeoutError:
-                            print(
-                                f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}"
-                            )
-                            json_results.append((f"level{level}", ref_arch_name, None))
-
-                        except Exception as e:
-                            print(
-                                f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}: {str(e)}"
-                            )
-                            json_results.append((f"level{level}", ref_arch_name, None))
-
-                pbar.update(len(curr_work_batch))
+    level = config.level
+    PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level))
+    dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR)
+    num_problems = len(dataset)
+    total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))]
+
+    with tqdm(total=len(total_work), desc="Processing batches") as pbar:
+        while len(total_work) > 0:
+            curr_work_batch = total_work[:config.batch_size]
+            total_work = total_work[config.batch_size:]  # pop the first batch_size elements
+
+            with mp.Pool() as pool:
+
+                work_args = [
+                    (
+                        config.gpu,
+                        ref_arch_name,
+                        ref_arch_src,
+                        config.num_trials,
+                        use_torch_compile,
+                        torch_compile_backend,
+                        torch_compile_options,
+                        torch.device(f"cuda:0"),
+                        False  # do not print
+                    )
+                    for i, (p_id, ref_arch_path, ref_arch_name, ref_arch_src) in enumerate(curr_work_batch)
+                ]
+
+                start_time = time.time()
+
+                async_results = []
+                for work_arg in work_args:
+                    async_results.append(
+                        pool.apply_async(measure_program_time_wrapper, work_arg)
+                    )
+
+                batch_timeout = config.timeout
+                for i, async_result in enumerate(async_results):
+                    problem_id, _, ref_arch_name, _ = curr_work_batch[i]
+
+                    try:
+                        elapsed_time = time.time() - start_time
+                        remaining_time = max(0, batch_timeout - elapsed_time)
+                        result = async_result.get(timeout=remaining_time)
+                        json_results.append((f"level{level}", ref_arch_name, result))
+
+                    except mp.TimeoutError:
+                        print(
+                            f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}"
+                        )
+                        json_results.append((f"level{level}", ref_arch_name, None))
+
+                    except Exception as e:
+                        print(
+                            f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}: {str(e)}"
+                        )
+                        json_results.append((f"level{level}", ref_arch_name, None))
+
+            pbar.update(len(curr_work_batch))

     save_path = os.path.join(TIMING_DIR, file_name)
     write_batch_to_json(json_results, save_path)
     return json_results
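A detail worth noting in the loop above: the whole batch shares a single timeout budget (config.timeout), and each get() waits only for whatever remains of that budget, so one stalled problem cannot hold the batch past its deadline. A self-contained sketch of the same pattern (toy function and numbers are illustrative, not from the PR):

import multiprocessing as mp
import time

def slow_square(x: int) -> int:
    time.sleep(x)
    return x * x

if __name__ == "__main__":
    budget = 5.0  # one shared deadline for the whole batch, in seconds
    start = time.time()
    with mp.Pool() as pool:
        pending = [pool.apply_async(slow_square, (x,)) for x in (1, 2, 10)]
        for i, res in enumerate(pending):
            # each result may only wait for what is left of the budget
            remaining = max(0, budget - (time.time() - start))
            try:
                print(i, res.get(timeout=remaining))
            except mp.TimeoutError:
                print(i, None)  # over budget: record a miss and move on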


-if __name__ == "__main__":
-    # DEBUG and simple testing
-    # test_measure_particular_program(2, 28)
-    gpu = "A10G"
-    # Replace this with whatever hardware you are running on
-    hardware_name = f"{gpu}_modal"
-    print(f"Generating baseline time for {hardware_name}")
-    # input(f"You are about to start recording baseline time for {hardware_name}, press Enter to continue...")
-    # # Systematic recording of baseline time
-
-    # if os.path.exists(os.path.join(TIMING_DIR, hardware_name)):
-    #     input(f"Directory {hardware_name} already exists, Are you sure you want to overwrite? Enter to continue...")
-
-    # 1. Record Torch Eager
-    record_baseline_times(use_torch_compile=False,
-                          torch_compile_backend=None,
-                          torch_compile_options=None,
-                          file_name=f"{hardware_name}/baseline_time_torch.json")
-
-    record_baseline_times(use_torch_compile=True,
-                          torch_compile_backend="inductor",
-                          torch_compile_options="default",
-                          file_name=f"{hardware_name}/baseline_time_torch_compile_inductor_default.json")
-
-    # 2. Record Torch Compile using Inductor
-    # for torch_compile_mode in ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]:
-    #     record_baseline_times(use_torch_compile=True,
-    #                           torch_compile_backend="inductor",
-    #                           torch_compile_options=torch_compile_mode,
-    #                           file_name=f"{hardware_name}/baseline_time_torch_compile_inductor_{torch_compile_mode}.json")
-
-    # 3. Record Torch Compile using cudagraphs
-    # record_baseline_times(use_torch_compile=True,
-    #                       torch_compile_backend="cudagraphs",
-    #                       torch_compile_options=None,
-    #                       file_name=f"{hardware_name}/baseline_time_torch_compile_cudagraphs.json")
+@pydra.main(base=BaselineConfig)
+def main(config: BaselineConfig):
+    """
+    Generate baseline time for KernelBench problems using Modal GPUs
+    """
+    print(f"Generating baseline time for level {config.level} on {config.gpu} via Modal")
+    print(f"Hardware name: {config.hardware_name}")
+    print(f"Batch size: {config.batch_size}, Timeout: {config.timeout}s, Num trials: {config.num_trials}")
+
+    # 1. Record baseline times with PyTorch Eager execution
+    print("\n[1/2] Recording baseline times with PyTorch Eager execution...")
+    record_baseline_times(
+        config=config,
+        use_torch_compile=False,
+        torch_compile_backend=None,
+        torch_compile_options=None,
+        file_name=f"{config.hardware_name}/baseline_time_torch.json"
+    )
+
+    # 2. Record baseline times with Torch Compile using Inductor (default mode)
+    print("\n[2/2] Recording baseline times with Torch Compile (inductor, default mode)...")
+    record_baseline_times(
+        config=config,
+        use_torch_compile=True,
+        torch_compile_backend="inductor",
+        torch_compile_options="default",
+        file_name=f"{config.hardware_name}/baseline_time_torch_compile_inductor_default.json"
+    )
+
+    print("\n✓ Baseline time generation complete!")
+    print(f"Results saved to: {os.path.join(TIMING_DIR, config.hardware_name)}")
+
+
+if __name__ == "__main__":
+    main()
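With the pydra entry point in place, a run is parameterized entirely by the config. A hypothetical end-to-end usage, assuming pydra's key=value CLI convention and that the script is importable from its repo path (neither assumption is stated in this diff):

# From the command line (assumed pydra key=value overrides):
#   python scripts/generate_baseline_time_modal.py level=1 gpu=L40S hardware_name=L40S_modal
#
# Or programmatically, constructing the config by hand (hypothetical import path):
from scripts.generate_baseline_time_modal import BaselineConfig, record_baseline_times

config = BaselineConfig()
config.level = 1
config.gpu = "L40S"
config.hardware_name = "L40S_modal"

record_baseline_times(
    config=config,
    use_torch_compile=False,
    torch_compile_backend=None,
    torch_compile_options=None,
    file_name=f"{config.hardware_name}/baseline_time_torch.json",
)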


