# Benchmark for TPU using PyTorch

This code runs a computational test to measure the performance a TPU can reach. It is an adaptation of my previous benchmark using PyTorch. However, pytorch-xla (the module that uses the TPU) is only available for PyTorch 1.6, not for the current version (1.7), so PyTorch's benchmark module is not included and has to be replaced by the timeit module.

Because the timeit module is used, executions can include a warm-up delay of approximately 2 us.
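
One way to keep a warm-up cost like this out of the reported figure is to call the statement a few times before timing it. The sketch below is only an illustration using the standard library; `time_with_warmup` and its parameters are hypothetical names and are not part of the benchmark in this notebook.

```python
import timeit


def time_with_warmup(fn, warmup=3, number=100):
    """Return the mean time per call of fn, in microseconds.

    fn is first called `warmup` times so that one-off costs (caches,
    lazy initialization) are not counted in the measurement.
    """
    for _ in range(warmup):
        fn()
    total_seconds = timeit.timeit(fn, number=number)
    return total_seconds / number * 1e6


# Stand-in workload; in the notebook this would be one of the batched dot calls.
mean_us = time_with_warmup(lambda: sum(range(1000)))
print(f"mean: {mean_us:>5.1f} us")
```

The warm-up calls are discarded, so the printed mean covers only the `number` timed calls.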

In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py  --apt-packages libomp5 libopenblas-dev # --version=pytorch-1.8

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  51160      0 --:--:-- --:--:-- --:--:-- 51160
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200515 ...
Found existing installation: torch 1.7.0
Uninstalling torch-1.7.0:
Done updating TPU runtime
  Successfully uninstalled torch-1.7.0
Found existing installation: torchvision 0.8.1
Uninstalling torchvision-0.8.1:
  Successfully uninstalled torchvision-0.8.1
Copying gs://tpu-pytorch/wheels/torch-nightly+20200515-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/91.0 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200515-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/119.5 MiB.                                    
Copying gs://tpu-pytorch/wheels/torchvision-nightly+202

In [2]:
# imports pytorch
import torch

print(torch.__version__)

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm
import platform
import os
# Libraries needed for the benchmark
import timeit
# import torch.utils.benchmark as benchmark  # the benchmark library is in PyTorch 1.7, which torch_xla is not compatible with
from itertools import product

1.6.0a0+bf2bbd9


In [3]:
# Functions taken from the PyTorch web pages on PyTorch benchmarks
def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to bmm'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)
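
Both functions compute the same thing: for every row `i`, the dot product of `a[i]` and `b[i]`. As a plain-Python reference (no torch required), a minimal sketch could look like this; `batched_dot_reference` is a hypothetical helper added for illustration only.

```python
def batched_dot_reference(a, b):
    """Row-wise dot product of two equally shaped lists of rows."""
    return [
        sum(x * y for x, y in zip(row_a, row_b))
        for row_a, row_b in zip(a, b)
    ]


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b = [[1.0, 0.0, 1.0], [0.5, 0.5, 0.5]]
print(batched_dot_reference(a, b))  # [4.0, 7.5]
```

`batched_dot_mul_sum` gets this result by multiplying element-wise and summing over the last dimension, while `batched_dot_bmm` treats each row as a 1 x n matrix times an n x 1 matrix.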

In [4]:
def benchMark(sizes, nThreads, dev):
    # nThreads is not used by this timeit-based version
    for n in sizes:
        x = torch.ones((n, n), device=dev)
        t0 = timeit.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x})

        t1 = timeit.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x})

        # Each statement runs 100 times; report the mean time per run in microseconds
        print('size of square matrix: ', n)
        print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
        print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us\n')
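
PyTorch/XLA records operations lazily and executes them later, so a timed statement that never reads its result may mostly measure the cost of queuing the work. Whether that applies to the numbers in this notebook is an assumption I have not verified. The sketch below shows the general idea with a stand-in lazy object instead of a TPU; `time_call`, `sync`, and `LazyValue` are hypothetical names.

```python
import timeit


def time_call(fn, sync=None, number=100):
    """Mean microseconds per call of fn, optionally forcing completion.

    sync (hypothetical hook) receives fn's result and should block until
    the result has really been computed.
    """
    def run_once():
        result = fn()
        if sync is not None:
            sync(result)

    return timeit.timeit(run_once, number=number) / number * 1e6


class LazyValue:
    """Stand-in for a lazy device: work happens only when force() is called."""

    def __init__(self, thunk):
        self.thunk = thunk

    def force(self):
        return self.thunk()


def lazy_sum():
    return LazyValue(lambda: sum(range(10_000)))


print(f"enqueue only: {time_call(lazy_sum):>7.1f} us")
print(f"with sync:    {time_call(lazy_sum, sync=LazyValue.force):>7.1f} us")
```

On a TPU, one possible `sync` (again an assumption, not something this notebook does) would be copying the result to the host, for example `lambda r: r.cpu()`.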

In [5]:
dev = xm.xla_device()

sizes = [512,2048,4096,8192,16384,32768,65536,131072,262144]
threads = [1]  # a single thread
compares = []

# Check that both methods above give the same result on the TPU
for n in sizes:
    x = torch.ones(n, n, device=dev)
    # If the assert raises no error, both methods agree
    assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

In [6]:
# Generate a .out file with the results.
# The benchmark only prints to stdout, so stdout has to be redirected to write the output to a file.
import sys

original_stdout = sys.stdout # Save a reference to the original standard output

with open('output_benchmark.out', 'w') as file:
    sys.stdout = file # Change the standard output to the file we created.
    # The benchmark is executed 5 times to gather data
    for i in range(5):
        print("Benchmark execution: ",i+1, "\n")
        benchMark(sizes,threads,dev)

sys.stdout = original_stdout # Reset the standard output to its original value
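
An alternative to swapping `sys.stdout` by hand is `contextlib.redirect_stdout` from the standard library, which restores the original stream even if the benchmark raises an exception. This is a minimal sketch; `run_to_file` and the demo file name are hypothetical, and a `print` call stands in for `benchMark(sizes, threads, dev)`.

```python
import contextlib
import os
import tempfile


def run_to_file(path, fn, repeats=5):
    """Call fn `repeats` times, writing everything it prints to `path`."""
    with open(path, "w") as out, contextlib.redirect_stdout(out):
        for i in range(repeats):
            print("Benchmark execution: ", i + 1, "\n")
            fn()


# Stand-in for benchMark(sizes, threads, dev)
demo_path = os.path.join(tempfile.gettempdir(), "output_benchmark_demo.out")
run_to_file(demo_path, lambda: print("size of square matrix:  512"), repeats=2)

with open(demo_path) as f:
    print(f.read(), end="")
```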

In [7]:
#Printing the results
with open('output_benchmark.out', 'r') as file:
    for line in file.readlines():
        print(line)
    

Benchmark execution:  1 



size of square matrix:  512

mul_sum(x, x):   14.8 us

bmm(x, x):       37.4 us



size of square matrix:  2048

mul_sum(x, x):    8.4 us

bmm(x, x):       33.8 us



size of square matrix:  4096

mul_sum(x, x):   10.3 us

bmm(x, x):       24.6 us



size of square matrix:  8192

mul_sum(x, x):    7.4 us

bmm(x, x):       34.2 us



size of square matrix:  16384

mul_sum(x, x):    7.0 us

bmm(x, x):       23.4 us



size of square matrix:  32768

mul_sum(x, x):    6.4 us

bmm(x, x):       23.4 us



size of square matrix:  65536

mul_sum(x, x):    7.0 us

bmm(x, x):       23.1 us



size of square matrix:  131072

mul_sum(x, x):    6.8 us

bmm(x, x):       23.2 us



size of square matrix:  262144

mul_sum(x, x):    6.7 us

bmm(x, x):       24.7 us



Benchmark execution:  2 



size of square matrix:  512

mul_sum(x, x):    6.9 us

bmm(x, x):       23.5 us



size of square matrix:  2048

mul_sum(x, x):    6.4 us

bmm(x, x):       23.8 us



size of square 