In [1]:
import torch
from torch.utils.benchmark import Compare, Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias, Timer

from binarize import su

In [2]:
opt_su = torch.compile(su, mode="reduce-overhead")

In [3]:
img_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter("h", minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter("w", minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor("img", size=(1, "h", "w"), probability_contiguous=1, cuda=True)
    ],
    seed=0,
)

In [6]:
measurements = []
for tensors, _, _ in img_fuzzer.take(10):
    measurement = Timer(
        stmt="su(img)",
        setup="from __main__ import su",
        globals=tensors,
        label="su",
        description="vanilla",
    ).blocked_autorange(min_run_time=1)
    opt_measurement = Timer(
        stmt="opt_su(img)",
        setup="from __main__ import opt_su",
        globals=tensors,
        label="su",
        description="compile",
    ).blocked_autorange(min_run_time=1)
    measurements.extend([measurement, opt_measurement])

In [7]:
compare = Compare(measurements)
compare.print()

[------------------ su -----------------]
                   |  vanilla  |  compile
1 threads: ------------------------------
      su(img)      |    74.3   |         
      opt_su(img)  |           |    72.4 

Times are in milliseconds (ms).

