diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py
index 97038e787..5fe470d04 100644
--- a/thunder/benchmarks/targets.py
+++ b/thunder/benchmarks/targets.py
@@ -71,14 +71,45 @@ def is_requires_grad(type: ComputeType):
     )


+import functools
+
+
+def timer_and_memory_stats(benchmark) -> Callable:
+    def deco(func):
+        @functools.wraps(func)
+        def wrapper():
+            ret = func()
+            benchmark.extra_info["max_allocated_memory(MB)"] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
+            torch.cuda.reset_peak_memory_stats()
+            return ret
+
+        return wrapper
+
+    return deco
+
+
+from contextlib import contextmanager
+
+
+@contextmanager
+def record_peak_allocated_memory(benchmark):
+    old_timer = benchmark._timer
+    benchmark._timer = timer_and_memory_stats(benchmark)(benchmark._timer)
+    try:
+        yield
+    finally:
+        benchmark._timer = old_timer
+
+
 def benchmark_for_compute_type(compute_type: ComputeType, benchmark, fn: Callable, args, kwargs):
-    match compute_type:
-        case ComputeType.INFERENCE | ComputeType.TRAINING_FORWARD:
-            benchmark(fn, *args, **kwargs)
-        case ComputeType.TRAINING_BACKWARD:
-            backward_fn, backward_setup = backward_only(fn, *args, **kwargs)
-            backward_args = backward_setup()
-            benchmark(backward_fn, *backward_args)
+    with record_peak_allocated_memory(benchmark):
+        match compute_type:
+            case ComputeType.INFERENCE | ComputeType.TRAINING_FORWARD:
+                benchmark(fn, *args, **kwargs)
+            case ComputeType.TRAINING_BACKWARD:
+                backward_fn, backward_setup = backward_only(fn, *args, **kwargs)
+                backward_args = backward_setup()
+                benchmark(backward_fn, *backward_args)


 def interpreter_fwd(module: Callable):
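
Note on how the patch works: it monkey-patches pytest-benchmark's private `BenchmarkFixture._timer` attribute, so every timestamp the harness takes also snapshots `torch.cuda.max_memory_allocated()` into the fixture's public `extra_info` dict and then resets the peak counter. Below is a minimal, self-contained sketch of the same pattern, not the patch itself: `FakeBenchmark` is a hypothetical stand-in for the pytest-benchmark fixture (only the `_timer` callable and `extra_info` dict that the patch actually touches are modeled), and a CUDA device is assumed.

```python
# Sketch of the timer-wrapping pattern from the patch above.
# Assumes torch with a CUDA device; FakeBenchmark is hypothetical.
import functools
import time
from contextlib import contextmanager

import torch


class FakeBenchmark:
    """Stand-in for pytest-benchmark's BenchmarkFixture (illustration only)."""

    def __init__(self):
        # pytest-benchmark calls this timer before and after each round.
        self._timer = time.perf_counter
        # extra_info entries end up in the benchmark's JSON report.
        self.extra_info = {}


def timer_and_memory_stats(benchmark):
    """Decorate a timer so every call also records peak CUDA memory."""

    def deco(func):
        @functools.wraps(func)
        def wrapper():
            ret = func()  # take the timestamp first, as in the patch
            benchmark.extra_info["max_allocated_memory(MB)"] = (
                torch.cuda.max_memory_allocated() / (1024 * 1024.0)
            )
            torch.cuda.reset_peak_memory_stats()  # next interval starts fresh
            return ret

        return wrapper

    return deco


@contextmanager
def record_peak_allocated_memory(benchmark):
    """Temporarily swap in the instrumented timer, always restoring it."""
    old_timer = benchmark._timer
    benchmark._timer = timer_and_memory_stats(benchmark)(benchmark._timer)
    try:
        yield
    finally:
        benchmark._timer = old_timer


if __name__ == "__main__":
    bench = FakeBenchmark()
    with record_peak_allocated_memory(bench):
        x = torch.randn(1024, 1024, device="cuda")  # allocate something measurable
        (x @ x).sum().item()
        bench._timer()  # pytest-benchmark would call this internally
    print(bench.extra_info)  # e.g. {'max_allocated_memory(MB)': ...}
```

One design consequence worth noting: because the wrapper runs on every timer call, the recorded value is the peak since the previous reset (i.e. since the last timer call), and the two CUDA stat calls execute inside the measured interval. That overhead is small, but it is part of the reported timings.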