Skip to content

Commit

Permalink
Merge branch 'main' into number_proxies
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Jun 4, 2024
2 parents 8fe43a9 + d1d581c commit fe6f311
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions thunder/benchmarks/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,45 @@ def is_requires_grad(type: ComputeType):
)


import functools


def timer_and_memory_stats(benchmark) -> float:
def deco(func):
@functools.wraps(func)
def wrapper():
ret = func()
benchmark.extra_info["max_allocated_memory(MB)"] = torch.cuda.max_memory_allocated() / (1024 * 1024.0)
torch.cuda.reset_peak_memory_stats()
return ret

return wrapper

return deco


from contextlib import contextmanager


@contextmanager
def record_peak_allocated_memory(benchmark):
old_timer = benchmark._timer
benchmark._timer = timer_and_memory_stats(benchmark)(benchmark._timer)
try:
yield
finally:
benchmark._timer = old_timer


def benchmark_for_compute_type(compute_type: ComputeType, benchmark, fn: Callable, args, kwargs):
match compute_type:
case ComputeType.INFERENCE | ComputeType.TRAINING_FORWARD:
benchmark(fn, *args, **kwargs)
case ComputeType.TRAINING_BACKWARD:
backward_fn, backward_setup = backward_only(fn, *args, **kwargs)
backward_args = backward_setup()
benchmark(backward_fn, *backward_args)
with record_peak_allocated_memory(benchmark):
match compute_type:
case ComputeType.INFERENCE | ComputeType.TRAINING_FORWARD:
benchmark(fn, *args, **kwargs)
case ComputeType.TRAINING_BACKWARD:
backward_fn, backward_setup = backward_only(fn, *args, **kwargs)
backward_args = backward_setup()
benchmark(backward_fn, *backward_args)


def interpreter_fwd(module: Callable):
Expand Down

0 comments on commit fe6f311

Please sign in to comment.