## Speed test: `torch.matmul` vs `torch.einsum`

In [3]:
import timeit

import torch
from torch.utils import benchmark

In [4]:
# Benchmark inputs: two batches of `bsz` square (d x d) matrices.
torch.manual_seed(0)  # reproducible inputs across kernel restarts
bsz = 32
d = 512
# Fall back to CPU when CUDA is absent, and allocate directly on the target
# device instead of creating on the host and copying over.
device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(bsz, d, d, device=device)
B = torch.randn(bsz, d, d, device=device)

In [5]:
import torch
import torch.nn as nn

class MatMulModule(nn.Module):
    """Two linear projections followed by a batched ``torch.matmul``.

    ``forward(x, y)`` returns ``wa(x) @ wb(y)``. ``nn.Linear`` maps the last
    dimension to ``dim``, so for the product to be defined ``y`` must have
    ``dim`` entries in its second-to-last dimension (e.g. square
    ``(bsz, dim, dim)`` inputs).

    Args:
        dim: width of both projections. ``None`` (the default) falls back to
            the notebook-level ``d`` so ``MatMulModule()`` behaves as before.
    """

    def __init__(self, dim=None):
        super().__init__()
        if dim is None:
            dim = d  # notebook global, looked up at construction time
        self.wa = nn.Linear(dim, dim)
        self.wb = nn.Linear(dim, dim)

    def forward(self, x, y):
        x = self.wa(x)
        y = self.wb(y)
        return torch.matmul(x, y)

class EinsumModule(nn.Module):
    """Two linear projections followed by a batched product via ``torch.einsum``.

    ``forward(x, y)`` returns ``einsum('bij,bjk->bik', wa(x), wb(y))``, the
    same value as ``torch.matmul(wa(x), wb(y))`` for 3-D inputs. ``y`` must
    have ``dim`` entries in its middle dimension so the contracted index
    ``j`` lines up.

    Args:
        dim: width of both projections. ``None`` (the default) falls back to
            the notebook-level ``d`` so ``EinsumModule()`` behaves as before.
    """

    def __init__(self, dim=None):
        super().__init__()
        if dim is None:
            dim = d  # notebook global, looked up at construction time
        self.wa = nn.Linear(dim, dim)
        self.wb = nn.Linear(dim, dim)

    def forward(self, x, y):
        x = self.wa(x)
        y = self.wb(y)
        return torch.einsum('bij,bjk->bik', x, y)

In [6]:
def step(model, x, y):
    """Forward ``model(x, y)``, reduce the output to a scalar, and backpropagate."""
    model(x, y).sum().backward()


# bfloat16 autocast on CUDA when it is available, otherwise on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
autocast_ctx_mgr = torch.autocast(device_type=device_type, dtype=torch.bfloat16)


def autocast_step(model, x, y):
    """Like ``step``, but the forward pass runs inside ``autocast_ctx_mgr``.

    Only the forward is wrapped; ``backward()`` is called after the autocast
    block has exited.
    """
    with autocast_ctx_mgr:
        out = model(x, y)
    out.sum().backward()

In [7]:
# Put the models on whatever device the benchmark inputs live on instead of
# hard-coding CUDA, consistent with the CPU fallback used for autocast.
matmul_model = MatMulModule().to(A.device)
einsum_model = EinsumModule().to(A.device)

In [8]:
# Eager fp32 forward + backward.
# CUDA kernels run asynchronously, so bare `timeit` mostly measures how fast
# work is enqueued. benchmark.Timer warms up first and synchronizes CUDA
# around the timed loop.
repeat = 1000
matmul_time, einsum_time = (
    benchmark.Timer(
        stmt="step(model, A, B)",
        globals={"step": step, "model": model, "A": A, "B": B},
        num_threads=torch.get_num_threads(),  # Timer would otherwise pin 1 thread
    ).timeit(repeat).mean * repeat  # total seconds, as timeit.timeit returned
    for model in (matmul_model, einsum_model)
)
print(f"torch.matmul execution time: {(matmul_time/repeat)*1000:.2f} milliseconds")
print(f"torch.einsum execution time: {(einsum_time/repeat)*1000:.2f} milliseconds")

torch.matmul execution time: 3.65 miliseconds
torch.einsum execution time: 3.79 miliseconds


In [9]:
# Eager forward under bf16 autocast, then backward.
# benchmark.Timer warms up and synchronizes CUDA around the timed loop;
# bare timeit would mostly time asynchronous kernel launches.
repeat = 1000
einsum_time, matmul_time = (
    benchmark.Timer(
        stmt="autocast_step(model, A, B)",
        globals={"autocast_step": autocast_step, "model": model, "A": A, "B": B},
        num_threads=torch.get_num_threads(),  # Timer would otherwise pin 1 thread
    ).timeit(repeat).mean * repeat  # total seconds, as timeit.timeit returned
    for model in (einsum_model, matmul_model)
)
print(f"torch.matmul execution time: {(matmul_time/repeat)*1000:.2f} milliseconds")
print(f"torch.einsum execution time: {(einsum_time/repeat)*1000:.2f} milliseconds")

torch.matmul execution time: 0.64 miliseconds
torch.einsum execution time: 0.65 miliseconds


In [15]:
# Compile both models. torch.compile is lazy: compilation happens on the first
# calls. Warm up with the same calls the timed cells make (forward + backward,
# with and without autocast), so that any compilation those paths trigger
# happens here and not inside the timed loops. This includes a backward graph
# compiled on first use and a recompile caused by the autocast state change.
matmul_model_ = torch.compile(matmul_model)
einsum_model_ = torch.compile(einsum_model)
for compiled_model in (matmul_model_, einsum_model_):
    step(compiled_model, A, B)
    autocast_step(compiled_model, A, B)

In [16]:

# Compiled models, fp32 forward + backward.
# benchmark.Timer warms up (absorbing any leftover compilation) and
# synchronizes CUDA around the timed loop; bare timeit would not.
repeat = 1000
matmul_time, einsum_time = (
    benchmark.Timer(
        stmt="step(model, A, B)",
        globals={"step": step, "model": model, "A": A, "B": B},
        num_threads=torch.get_num_threads(),  # Timer would otherwise pin 1 thread
    ).timeit(repeat).mean * repeat  # total seconds, as timeit.timeit returned
    for model in (matmul_model_, einsum_model_)
)
print(f"torch.matmul execution time: {(matmul_time/repeat)*1000:.2f} milliseconds")
print(f"torch.einsum execution time: {(einsum_time/repeat)*1000:.2f} milliseconds")

torch.matmul execution time: 3.54 miliseconds
torch.einsum execution time: 3.73 miliseconds


In [17]:
# Compiled models, forward under bf16 autocast, then backward.
# benchmark.Timer warms up first, so a one-time recompile for the autocast
# state (if one is triggered) is not timed, and it synchronizes CUDA.
repeat = 1000
matmul_time, einsum_time = (
    benchmark.Timer(
        stmt="autocast_step(model, A, B)",
        globals={"autocast_step": autocast_step, "model": model, "A": A, "B": B},
        num_threads=torch.get_num_threads(),  # Timer would otherwise pin 1 thread
    ).timeit(repeat).mean * repeat  # total seconds, as timeit.timeit returned
    for model in (matmul_model_, einsum_model_)
)
print(f"torch.matmul execution time: {(matmul_time/repeat)*1000:.2f} milliseconds")
print(f"torch.einsum execution time: {(einsum_time/repeat)*1000:.2f} milliseconds")

torch.matmul execution time: 0.65 miliseconds
torch.einsum execution time: 0.66 miliseconds


## `torch.einsum` / `np.einsum` vs `opt_einsum`

In [13]:
import torch
from opt_einsum import contract
import numpy as np

In [5]:
# GPU inputs for the attention/relation contraction.
# Lower-case names so the batch size does not overwrite the matrix `B` used by
# the matmul benchmark (the old unused model width `D` is dropped).
batch, seq_len, n_heads, n_rel = 32, 1024, 8, 16

# Allocate directly on the GPU: attn_scores is ~1 GiB and relations ~2 GiB in
# float32, so building them on the host first doubles the memory and adds a copy.
# The pre-softmax scores are not kept, which frees ~1 GiB.
attn_scores = torch.softmax(
    torch.randn(batch, n_heads, seq_len, seq_len, device="cuda"), dim=-1
)
relations = torch.randn(batch, seq_len, seq_len, n_rel, device="cuda")

In [9]:
%%timeit
# Output shape: (bs, seqlen, n_heads, n_relations).
# Synchronize so each timed call includes the GPU kernels, not just their launch.
torch.einsum('bhij,bijr->bihr', attn_scores, relations)
torch.cuda.synchronize()

5.35 ms ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%%timeit
# Output shape: (bs, seqlen, n_heads, n_relations).
# Synchronize so each timed call includes the GPU kernels, not just their launch.
contract('bhij,bijr->bihr', attn_scores, relations, backend='torch')
torch.cuda.synchronize()

The slowest run took 10.94 times longer than the fastest. This could mean that an intermediate result is being cached.
28.3 ms ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# Smaller CPU/NumPy inputs for the np.einsum vs opt_einsum comparison.
# Lower-case names so the batch size does not overwrite the matrix `B` used by
# the matmul benchmark (the old unused model width `D` is dropped).
batch, seq_len, n_heads, n_rel = 32, 256, 4, 8

# Rows of attn_scores sum to 1 over the last axis; the pre-softmax tensor is not kept.
attn_scores = torch.softmax(torch.randn(batch, n_heads, seq_len, seq_len), dim=-1).numpy()
relations = torch.randn(batch, seq_len, seq_len, n_rel).numpy()

In [20]:
%%timeit
# Output shape: (bs, seqlen, n_heads, n_relations).
np.einsum("bhij,bijr->bihr", attn_scores, relations)

88.7 ms ± 216 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [21]:
%%timeit
# Output shape: (bs, seqlen, n_heads, n_relations); NumPy arrays, default backend.
contract("bhij,bijr->bihr", attn_scores, relations)

88.9 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
