In [1]:
# (Re)load IPython's autoreload extension; mode 2 picks up edits to imported modules live.
%reload_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
import torch 
from torchvision import transforms
from torchvision.datasets import MNIST

In [4]:
# Root directory for the MNIST files. Defaults to the original location but can be
# overridden with the MNIST_DATA_DIR environment variable, so the notebook is not tied
# to one machine's filesystem layout.
DATA_DIR = os.environ.get('MNIST_DATA_DIR', '/workspace/data/')

# ToTensor() is registered as the per-sample transform; download=True asks torchvision
# to fetch the files when they are not already under DATA_DIR.
dataset = MNIST(DATA_DIR, download=True, transform=transforms.ToTensor())
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspace/data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
# Inspect the dimensions of the raw image tensor before reshaping it.
dataset.data.shape

torch.Size([60000, 28, 28])

In [6]:
# Hold out everything after the first 50,000 samples for validation.
n_train = 50_000
n_valid = dataset.data.size(0) - n_train

# Flatten each image into a single row and divide by 255
# (presumably mapping 8-bit pixel intensities into [0, 1] — confirm against dataset.data.dtype).
x_train = dataset.data[:n_train].view(n_train, -1) / 255
y_train = dataset.targets[:n_train]

x_valid = dataset.data[n_train:].view(n_valid, -1) / 255
y_valid = dataset.targets[n_train:]

In [7]:
# Parameters of a single linear layer: 784 inputs -> 10 outputs.
# Weights are drawn from a standard normal distribution; biases start at zero.
weights = torch.randn((784, 10))
biases = torch.zeros((10,))

In [8]:
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors with a naive triple Python loop.

    This is the deliberately slow reference implementation: every scalar
    multiply-add is executed by the Python interpreter.

    Args:
        a: Left operand of shape (ar, ac).
        b: Right operand of shape (br, bc); `br` must equal `ac`.

    Returns:
        Tensor of shape (ar, bc) holding the matrix product. It is allocated on
        `a`'s device with the promoted dtype of `a` and `b`, so float64 inputs
        are no longer silently downcast to float32.

    Raises:
        ValueError: If either operand is not 2-D (shape unpacking fails).
        AssertionError: If the inner dimensions do not match.
    """
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, f"inner dimensions must match: {ac} != {br}"

    # Follow the inputs' dtype/device instead of hard-coding float32 on CPU.
    c = torch.zeros(ar, bc, dtype=torch.result_type(a, b), device=a.device)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac):
                c[i, j] += a[i, k] * b[k, j]
    return c

Let's benchmark it against the PyTorch version. 

In [9]:
# Benchmark operands: the first five validation rows against the weight matrix.
m1, m2 = x_valid[:5], weights

In [10]:
# Display both operand shapes; the inner dimensions must agree for the product to exist.
m1.shape, m2.shape

(torch.Size([5, 784]), torch.Size([784, 10]))

In [11]:
# Time the triple-loop version; %time keeps the assignment, so t1 is available afterwards.
%time t1 = matmul(m1, m2)

CPU times: user 407 ms, sys: 0 ns, total: 407 ms
Wall time: 406 ms


In [12]:
# Baseline: PyTorch's built-in matmul on the same operands; t2 is the reference result.
%time t2 = torch.matmul(m1, m2)

CPU times: user 88 µs, sys: 37 µs, total: 125 µs
Wall time: 128 µs


In [13]:
def test_near(a: torch.tensor, b:torch.tensor):
    return torch.allclose(a, b, rtol=1e-3, atol=1e-5)

In [14]:
# Compare the triple-loop result with the torch.matmul reference.
test_near(t1, t2)

True

The output is the same, but our version is significantly slower.

## Vectorization

In [15]:
def matmul_fast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors, vectorizing the innermost (reduction) loop.

    Each output element is computed as the sum of the element-wise product of
    one row of `a` and one column of `b`, leaving two Python-level loops.

    Args:
        a: Left operand of shape (ar, ac).
        b: Right operand of shape (br, bc); `br` must equal `ac`.

    Returns:
        Tensor of shape (ar, bc) holding the matrix product. It is allocated on
        `a`'s device with the promoted dtype of `a` and `b`, so float64 inputs
        are no longer silently downcast to float32.

    Raises:
        ValueError: If either operand is not 2-D (shape unpacking fails).
        AssertionError: If the inner dimensions do not match.
    """
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, f"inner dimensions must match: {ac} != {br}"

    # Follow the inputs' dtype/device instead of hard-coding float32 on CPU.
    c = torch.zeros(ar, bc, dtype=torch.result_type(a, b), device=a.device)
    for i in range(ar):
        for j in range(bc):
            # Dot product of row i of `a` with column j of `b`.
            c[i, j] = (a[i, :] * b[:, j]).sum()
    return c

In [16]:
# Time the version whose innermost loop is vectorized; the result is kept in t3.
%time t3 = matmul_fast(m1, m2)

CPU times: user 829 µs, sys: 353 µs, total: 1.18 ms
Wall time: 914 µs


In [17]:
# Compare the vectorized result with the torch.matmul reference.
test_near(t3, t2)

True

By converting the innermost loop to a vectorized operation, we cut the wall time from roughly 406 ms to under 1 ms — more than 400× faster. The gain would be even larger with bigger matrices.

## Broadcasting

But we can do better than that, and remove one more loop.

In [18]:
def matmul_faster(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors using broadcasting, one output row per iteration.

    Only the loop over the rows of `a` remains in Python; each iteration
    produces a whole output row at once.

    Args:
        a: Left operand of shape (ar, ac).
        b: Right operand of shape (br, bc); `br` must equal `ac`.

    Returns:
        Tensor of shape (ar, bc) holding the matrix product. It is allocated on
        `a`'s device with the promoted dtype of `a` and `b`, so float64 inputs
        are no longer silently downcast to float32.

    Raises:
        ValueError: If either operand is not 2-D (shape unpacking fails).
        AssertionError: If the inner dimensions do not match.
    """
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, f"inner dimensions must match: {ac} != {br}"

    # Follow the inputs' dtype/device instead of hard-coding float32 on CPU.
    c = torch.zeros(ar, bc, dtype=torch.result_type(a, b), device=a.device)
    for i in range(ar):
        # Row i reshaped to (ac, 1) broadcasts across the bc columns of `b`;
        # summing over dim 0 collapses the shared inner dimension.
        c[i, :] = (a[i, :].unsqueeze(-1) * b).sum(dim=0)
    return c

In [19]:
# Time the broadcasting version (a single Python loop remains); the result is kept in t4.
%time t4 = matmul_faster(m1, m2)

CPU times: user 428 µs, sys: 0 ns, total: 428 µs
Wall time: 278 µs


In [20]:
# Compare the broadcasting result with the torch.matmul reference.
test_near(t4, t2)

True

## Einstein Summation

Let's see if we can improve on this, and maybe even beat the PyTorch implementation.

In [21]:
def matmul_fastest(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2-D tensors by delegating to `torch.einsum`.

    The subscript string names the shared index `k`, which is summed away,
    leaving an output indexed by `i` (rows of `a`) and `j` (columns of `b`).

    Args:
        a: Left operand, indexed as (i, k).
        b: Right operand, indexed as (k, j).

    Returns:
        The matrix product, indexed as (i, j).
    """
    subscripts = 'ik,kj->ij'
    product = torch.einsum(subscripts, a, b)
    return product

In [22]:
# Time the einsum-based version; the result is kept in t5.
%time t5 = matmul_fastest(m1, m2)

CPU times: user 83 µs, sys: 36 µs, total: 119 µs
Wall time: 122 µs


In [23]:
# Compare the einsum result with the torch.matmul reference.
test_near(t5, t2)

True

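The `torch.einsum` version is on par with `torch.matmul` (122 µs vs. 128 µs here — a difference well within the noise of a single `%time` run; use `%timeit` for a reliable comparison). That's pretty remarkable.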