# Demonstrate torch + MPS (mac "metal")

In [41]:
import torch
if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
    x = torch.ones(1, device=torch_device)
    print (x)
else:
    torch_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


## Benchmark example
https://pytorch.org/tutorials/recipes/recipes/benchmark.html

In [42]:
import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64, device=torch_device)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

In [43]:
import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

mul_sum(x, x):   63.0 us
bmm(x, x):       23.1 us


In [44]:
import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))

<torch.utils.benchmark.utils.common.Measurement object at 0x136111ed0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  63.37 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x136111ed0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  21.67 us
  1 measurement, 100 runs , 1 thread


In [45]:
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))

Benchmarking on 8 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x135e43610>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  183.99 us
  1 measurement, 100 runs , 8 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x136113710>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  19.22 us
  1 measurement, 100 runs , 8 threads


In [46]:
x = torch.randn(10000, 1024, device=torch_device)

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

mul_sum(x, x):   67.7 us
mul_sum(x, x):   66.3 us
bmm(x, x):      100.6 us
bmm(x, x):      134.5 us


In [47]:
from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 8, 32]:
        print(b,n,num_threads)
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

1 1 1
1 1 4
1 1 8
1 1 32
1 64 1
1 64 4
1 64 8
1 64 32
1 1024 1
1 1024 4
1 1024 8
1 1024 32
1 10000 1
1 10000 4
1 10000 8
1 10000 32
64 1 1
64 1 4
64 1 8
64 1 32
64 64 1
64 64 4
64 64 8
64 64 32
64 1024 1
64 1024 4
64 1024 8
64 1024 32
64 10000 1
64 10000 4
64 10000 8
64 10000 32
1024 1 1
1024 1 4
1024 1 8
1024 1 32
1024 64 1
1024 64 4
1024 64 8
1024 64 32
1024 1024 1
1024 1024 4
1024 1024 8
1024 1024 32
1024 10000 1
1024 10000 4
1024 10000 8
1024 10000 32
10000 1 1
10000 1 4
10000 1 8
10000 1 32
10000 64 1
10000 64 4
10000 64 8
10000 64 32
10000 1024 1
10000 1024 4
10000 1024 8
10000 1024 32
10000 10000 1
10000 10000 4
10000 10000 8
10000 10000 32
[-------------- Batched dot ---------------]
                      |  mul/sum  |    bmm  
1 threads: ---------------------------------
      [1, 1]          |      2.0  |      3.3
      [1, 64]         |      2.0  |      3.3
      [1, 1024]       |      2.0  |      5.1
      [1, 10000]      |      3.0  |     10.6
      [64, 1]         |      