In [45]:
import torch
import timeit

import torch.utils.benchmark as benchmark

## Setup

In [43]:
def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)

x = torch.randn(10000, 64)

## Benchmarking with timeit.Timer

In [44]:
t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(10) / 10 * 1e6:>5.1f} us')

mul_sum(x, x):  3735.7 us


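timeit.Timer takes the statement to time, optional setup code, and an optional timer function. Its timeit() method runs the statement the given number of times and returns the total elapsed time in seconds, which is why the result above is divided by the number of runs. The optional globals argument specifies the namespace in which the code executes.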
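To make the total-versus-per-run distinction concrete, here is a minimal sketch that uses only the standard library. The function add_one and the run count are illustrative and not part of the benchmark above.

```python
import timeit

def add_one(v):
    """Trivial illustrative workload."""
    return v + 1

# globals supplies every name that the statement refers to.
timer = timeit.Timer(stmt='add_one(v)', globals={'add_one': add_one, 'v': 41})

number = 1000
total_seconds = timer.timeit(number)          # total time for all runs
per_run_us = total_seconds / number * 1e6     # convert to microseconds per run

print(f'total:   {total_seconds:.6f} s for {number} runs')
print(f'per run: {per_run_us:.3f} us')
```

Dividing by the run count is left to the caller, which is easy to forget when comparing numbers across tools.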

## Benchmarking with torch.utils.benchmark.Timer

Although the APIs are the same for the basic functionality, there are some important differences. benchmark.Timer.timeit() returns the time per run, whereas timeit.Timer.timeit() returns the total runtime. The PyTorch benchmark module also provides formatted string representations for printing the results.

Another important difference, and the reason the results diverge, is that the PyTorch benchmark module runs in a single thread by default. We can change the number of threads with the num_threads argument. Running the benchmark with all available threads should give results similar to those of the timeit module. More importantly, when comparing implementations, which one is faster can depend on how many threads the code runs with.

Finally, PyTorch’s benchmark module takes care of warmups, whereas the timeit module does not.

In [47]:
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

print(t0.timeit(10))

<torch.utils.benchmark.utils.common.Measurement object at 0x10c79cd60>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  698.03 us
  1 measurement, 10 runs , 1 thread
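
The printed summary is built from a Measurement object whose values can also be read directly. The following is a small sketch, assuming the mean and median properties (seconds per run) and the number_per_run field of Measurement. The function is passed through globals rather than imported from __main__ so that the block also runs outside a notebook.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    """Computes batched dot by multiplying and summing."""
    return a.mul(b).sum(-1)

x = torch.randn(10000, 64)

timer = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum})

m = timer.timeit(10)

# Values are already per run, so no division by the run count is needed.
print(f'mean:   {m.mean * 1e6:.1f} us per run')
print(f'median: {m.median * 1e6:.1f} us per run')
print(f'runs:   {m.number_per_run}')
```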


#### Adding more threads

In [51]:
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    num_threads=10,
    globals={'x': x})

print(t0.timeit(10))

<torch.utils.benchmark.utils.common.Measurement object at 0x10c68f2b0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  1.14 ms
  1 measurement, 10 runs , 10 threads
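
Because a single 10-run measurement is noisy, it can help to sweep over several thread counts with more runs each. This is a sketch under assumed settings: the thread counts (1, 2, 4) and the run count of 100 are arbitrary choices, and the numbers will vary by machine.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    """Computes batched dot by multiplying and summing."""
    return a.mul(b).sum(-1)

x = torch.randn(10000, 64)

mean_seconds = {}
for num_threads in (1, 2, 4):  # illustrative thread counts
    timer = benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum},
        num_threads=num_threads)
    measurement = timer.timeit(100)
    mean_seconds[num_threads] = measurement.mean  # seconds per run
    print(f'{num_threads} thread(s): {measurement.mean * 1e6:8.1f} us per run')
```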


#### Playing with more attributes

torch.utils.benchmark.Timer takes several additional arguments, including label, sub_label, description and env. These change the `__repr__` of the returned Measurement object and are used for grouping the results (more on this later).

In [52]:
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=10,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

print(t0.timeit(10))

<torch.utils.benchmark.utils.common.Measurement object at 0x10c7e5a00>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  588.45 us
  1 measurement, 10 runs , 10 threads
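
As a brief preview of how these attributes group results, the sketch below collects two measurements and passes them to benchmark.Compare. The label, sub_label and description strings, the thread counts and the run count are illustrative assumptions, and the exact table layout may differ between PyTorch versions.

```python
import torch
import torch.utils.benchmark as benchmark

def batched_dot_mul_sum(a, b):
    """Computes batched dot by multiplying and summing."""
    return a.mul(b).sum(-1)

x = torch.randn(10000, 64)

results = []
for num_threads in (1, 4):  # illustrative thread counts
    timer = benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum},
        num_threads=num_threads,
        label='Batched dot',            # shared title for the table
        sub_label='mul and sum',        # identifies the implementation
        description='x: 10000 x 64')    # identifies the input
    results.append(timer.timeit(100))

# Compare arranges measurements that share a label into one table.
compare = benchmark.Compare(results)
compare.print()
```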
