In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from pathlib2 import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor

### Matrix multiplication

#### with elementwise operations

In [3]:
# 3x3 integer matrix used as the left operand throughout; deliberately not cast to float.
a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

In [4]:
# Seed the global RNG so that `b` (and every result derived from it) is the
# same on each Restart & Run All.
torch.manual_seed(42)
# 3x3 matrix of random integers drawn from [0, 5); `high` is exclusive.
b = torch.randint(high=5, size=(3,3))

In [5]:
# a = a.float(); b = b.float()
"""
We will intentionally leave out casting to float to understand where
casting is essential and otherwise
"""

In [6]:
def matmul(a,b):
    """Multiply two 2-D tensors using one elementwise product per output cell.

    Args:
        a: tensor of shape (ar, ac).
        b: tensor of shape (br, bc); requires ac == br.

    Returns:
        Tensor of shape (ar, bc) created with ``torch.zeros``, so it has
        torch's default dtype (float) regardless of the dtype of the inputs.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    ar,ac = a.shape
    br,bc = b.shape
    # The inner dimensions must agree: (ar x ac) @ (br x bc) needs ac == br.
    # The output dimensions ar and bc are unrelated to each other.
    assert ac==br
    c = torch.zeros(ar,bc)
    for i in range(ar):
        for j in range(bc):
            # Dot product of row i of `a` with column j of `b`.
            c[i,j] = (a[i,:]*b[:,j]).sum(dim=0)
    return c

In [7]:
%timeit -n 10 _=matmul(a,b)

126 µs ± 20.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### with broadcasting

Broadcasting only makes sense if one of the inputs needs it, so we will add another case, 'c'. We will still test the performance on the square matrices.

In [8]:
# 1-D vector for the non-square case; some call sites turn it into a 3x1 column with c[:,None].
c = tensor([1, 2, 3])
# c = c.float()

In [9]:
def matmul_br(a,b):
    """Multiply two 2-D tensors, computing one full output row per iteration.

    Each row of ``a`` is turned into a column and broadcast against ``b``;
    summing the product over dim 0 yields the corresponding output row.

    Args:
        a: tensor of shape (ar, ac).
        b: tensor of shape (br, bc); requires ac == br.

    Returns:
        Tensor of shape (ar, bc) allocated with ``torch.zeros``.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    n_rows, inner = a.shape
    b_rows, n_cols = b.shape
    assert inner==b_rows
    out = torch.zeros(n_rows, n_cols)
    for idx, row in enumerate(a):
        # (inner, 1) * (inner, n_cols) broadcasts to (inner, n_cols)
        scaled = row[:, None] * b
        out[idx] = scaled.sum(dim=0)
    return out

In [10]:
%timeit -n 10 _=matmul_br(a,c[:,None])

49 µs ± 7.13 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit -n 10 _=matmul_br(a,b)

43.7 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


After re-running this notebook several times, I found that the square matrices are always faster and have a smaller standard deviation.

In [12]:
assert (matmul(a,b) == matmul_br(a,b)).all()

We will always check the above condition for all variants of matmul_*, with matmul being the base case.

#### with einsum

In [13]:
def matmul_es(a,b):
    """Multiply two 2-D tensors with ``torch.einsum``.

    The subscripts 'ik,kj->ij' contract over the shared index ``k``.
    The result is returned exactly as einsum produces it (no dtype cast).
    """
    product = torch.einsum('ik,kj->ij', a, b)
    return product

In [14]:
%timeit -n 10 _=matmul_es(a,b)

The slowest run took 4.07 times longer than the fastest. This could mean that an intermediate result is being cached.
23.3 µs ± 17 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


timeit's warning suggests that something may be wrong; however, it may simply be that the arrays enter the CPU cache only after the first run, which would make the first run the slowest.

__TODO: See if this can be avoided, other than putting everything on GPU__

In [15]:
%timeit -n 10 _=matmul_es(a,c[:,None])

The slowest run took 11.39 times longer than the fastest. This could mean that an intermediate result is being cached.
48.6 µs ± 68.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
assert (matmul(a,b) == matmul_es(a,b)).all()

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'other'

In [17]:
matmul(a,b).type(), matmul_es(a,b).type()

('torch.FloatTensor', 'torch.LongTensor')

The above error occurs because we did not cast our inputs to float. However, since matmul initializes its result with torch.zeros and assigns values into it in place, its output was cast to float.

einsum, on the other hand, simply preserves the dtype of its inputs and returns a LongTensor. A simple fix is to cast the output of einsum to float; let's check that.

In [18]:
def matmul_esf(a,b):
    """Multiply two 2-D tensors with ``torch.einsum`` and cast the result to float.

    The subscripts 'ik,kj->ij' contract over the shared index ``k``; the
    trailing ``.float()`` converts the einsum output to a float tensor.
    """
    product = torch.einsum('ik,kj->ij', a, b)
    return product.float()

In [19]:
%timeit -n 10 _=matmul_esf(a,b)

The slowest run took 62.05 times longer than the fastest. This could mean that an intermediate result is being cached.
188 µs ± 409 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
%timeit -n 10 _=matmul_esf(a,c[:,None])

The slowest run took 111.74 times longer than the fastest. This could mean that an intermediate result is being cached.
374 µs ± 859 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [21]:
assert (matmul(a,b) == matmul_esf(a,b)).all()

#### with pytorch

In [22]:
%timeit -n 10 _=a.matmul(b)

The slowest run took 5.87 times longer than the fastest. This could mean that an intermediate result is being cached.
3.09 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [23]:
%timeit -n 10 _=a.matmul(c)

The slowest run took 13.63 times longer than the fastest. This could mean that an intermediate result is being cached.
4.79 µs ± 7.53 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
assert (a.matmul(b) == matmul_es(a,b)).all()

In [26]:
%timeit -n 10 _=a@b

The slowest run took 15.11 times longer than the fastest. This could mean that an intermediate result is being cached.
5.64 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [27]:
%timeit -n 10 _=a@c

The slowest run took 29.17 times longer than the fastest. This could mean that an intermediate result is being cached.
13 µs ± 17.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
assert (a@b == matmul_es(a,b)).all()