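einsum (Einstein summation) is a concise way to express tensor contractions, reductions, outer products, and attention ops using index notation. Each input gets a string of index letters; any index that appears in the inputs but not after the `->` is summed over.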
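The sections below cover products and contractions. Reductions follow the same rule, and the short sketch here illustrates it on two small hand-written matrices (`M` and `S` are illustrative names, not from the original notebook).

```python
import torch

M = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])   # (2, 3)
S = torch.tensor([[1., 2.],
                  [3., 4.]])       # (2, 2), square so the diagonal is defined

total = torch.einsum('ij->', M)         # both indices dropped: sum of all elements
col_sums = torch.einsum('ij->j', M)     # i dropped: sum over rows, one value per column
row_sums = torch.einsum('ij->i', M)     # j dropped: sum over columns, one value per row
trace = torch.einsum('ii->', S)         # repeated index on one operand: diagonal, then summed
transposed = torch.einsum('ij->ji', M)  # nothing dropped: only the axis order changes

print(total, col_sums, row_sums, trace, transposed.shape)
```

Indices kept on the right-hand side survive in that order; everything else is reduced away.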

Dot Product

In [2]:
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

out = torch.einsum('i,i->', a, b)
print(out)

tensor(32.)


Outer Product

In [3]:
a = torch.tensor([1., 2.])
b = torch.tensor([4., 5., 6.])

out = torch.einsum('i,j->ij', a, b)
print(out)

tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.]])


Matrix Vector Multiplication

In [4]:
A = torch.tensor([[1., 2., 3.], [4., 5., 6.]])  # (2, 3)
x = torch.tensor([7., 8., 9.])                 # (3,)

out = torch.einsum('ij,j->i', A, x)
print(out)

tensor([ 50., 122.])


Matrix Matrix Multiplication

In [5]:
A = torch.randn(2, 3)
B = torch.randn(3, 4)

out = torch.einsum('ik,kj->ij', A, B)
print(out)

tensor([[ 0.0750,  0.5085,  0.8904, -2.2504],
        [-0.2206,  0.1441,  1.8318, -0.2015]])
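Because the inputs are random, the printed values differ on every run. A quick way to convince yourself that `'ik,kj->ij'` really is matrix multiplication is to compare it against `A @ B` and against explicit loops. This is a minimal sketch; the fixed seed and the `via_*` names are assumptions added for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the comparison is repeatable
A = torch.randn(2, 3)
B = torch.randn(3, 4)

via_einsum = torch.einsum('ik,kj->ij', A, B)
via_matmul = A @ B

# Spell out the contraction: out[i, j] = sum over k of A[i, k] * B[k, j]
via_loops = torch.zeros(2, 4)
for i in range(2):
    for j in range(4):
        for k in range(3):
            via_loops[i, j] += A[i, k] * B[k, j]

print(torch.allclose(via_einsum, via_matmul, atol=1e-5))
print(torch.allclose(via_einsum, via_loops, atol=1e-5))
```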


Batched Matrix Multiplication

In [6]:
A = torch.randn(10, 2, 3)
B = torch.randn(10, 3, 4)

out = torch.einsum('bij,bjk->bik', A, B)
print(out)

tensor([[[ 0.0195, -0.7816, -0.8763,  2.2061],
         [ 0.5668,  1.1220, -1.2224,  0.1161]],

        [[-0.3294, -0.1741,  0.4972,  3.6834],
         [-0.7849, -2.6441, -1.6246,  1.5197]],

        [[ 1.3685,  1.8194, -0.4677,  1.3664],
         [-0.9079,  1.1539, -2.0940, -1.2011]],

        [[-2.2159, -0.4520,  0.2183,  0.0462],
         [ 1.6361,  0.1408, -1.5980,  0.5484]],

        [[ 0.4013,  0.6323,  1.6513, -2.1926],
         [ 2.7024,  0.3602,  0.8489, -2.5955]],

        [[ 0.7844, -0.6573,  2.4862,  2.8324],
         [ 0.2876,  0.4176, -1.8220, -2.4041]],

        [[-1.1543,  1.0726, -1.5795,  1.2047],
         [-1.0551,  1.3986,  0.6860,  0.9939]],

        [[ 2.4141, -2.3396, -0.7182, -1.0035],
         [ 0.6322,  0.8609,  0.8573, -0.2625]],

        [[ 0.9789,  1.9025,  0.1274, -2.3324],
         [ 1.0936,  0.4009, -0.1842, -0.9831]],

        [[ 0.2554,  0.1552, -2.9759,  0.2129],
         [ 3.7971,  0.1569, -2.4792,  2.4251]]])


Softmax Attention Scores


Computes Q × Kᵀ per head.

Each score is the dot product between one query and one key. These are raw scores; the softmax over the key axis is applied afterwards.

In [7]:
Q = torch.randn(64, 8, 20, 64)  # (batch, heads, queries, dim)
K = torch.randn(64, 8, 64, 64)  # (batch, heads, keys, dim)

scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
print(scores.shape)

torch.Size([64, 8, 20, 64])


Context Vector via Attention

In [9]:
attn = torch.randn(64, 8, 20, 64)  # stand-in for attention weights: (batch, heads, queries, keys)
V = torch.randn(64, 8, 64, 64)     # values: (batch, heads, keys, dim)
context = torch.einsum('bhqk,bhkd->bhqd', attn, V)  # weighted sum of values, contracting over the key axis

print(context.shape)

torch.Size([64, 8, 20, 64])
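The two einsum calls above can be chained into one attention pass. The sketch below is one way to do it, under assumptions not in the original cells: smaller shapes, a fixed seed, and the usual 1/sqrt(dim) scaling from scaled dot-product attention.

```python
import math
import torch

torch.manual_seed(0)
Q = torch.randn(2, 4, 5, 16)  # (batch, heads, queries, dim)
K = torch.randn(2, 4, 7, 16)  # (batch, heads, keys, dim)
V = torch.randn(2, 4, 7, 16)  # (batch, heads, keys, dim)

# Raw scores, scaled by sqrt(dim) (assumed here; the original cell leaves scores unscaled)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) / math.sqrt(Q.shape[-1])

# Softmax over the key axis turns each query's scores into weights that sum to 1
weights = torch.softmax(scores, dim=-1)

# Weighted sum of the values for each query
context = torch.einsum('bhqk,bhkd->bhqd', weights, V)

print(scores.shape, weights.sum(dim=-1)[0, 0], context.shape)
```

The `k` index is what ties the weights to the values: it appears in both inputs and is absent from the output, so it is the axis being summed.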


Is the operation:

🔲 A simple dot/matmul/batch matmul?
    → Use matmul / bmm

🔲 Complex contraction, like multi-axis or attention?
    → Use einsum

🔲 You care about speed or memory?
    → Profile first. Rewrite einsum to bmm/matmul if needed.
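As a rough illustration of the last item, the attention-score contraction can be rewritten with `matmul` or `bmm`. This is a sketch with small assumed shapes; whether either form is actually faster depends on hardware and PyTorch version, so it should be measured rather than assumed.

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 4, 5, 16)  # (batch, heads, queries, dim)
K = torch.randn(2, 4, 7, 16)  # (batch, heads, keys, dim)

scores_einsum = torch.einsum('bhqd,bhkd->bhqk', Q, K)

# matmul broadcasts over the leading (batch, heads) axes
scores_matmul = torch.matmul(Q, K.transpose(-2, -1))

# bmm only accepts 3-D inputs, so fold batch and heads into one axis first
b, h, q, d = Q.shape
k = K.shape[2]
scores_bmm = torch.bmm(
    Q.reshape(b * h, q, d),
    K.reshape(b * h, k, d).transpose(1, 2),
).reshape(b, h, q, k)

print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))
print(torch.allclose(scores_einsum, scores_bmm, atol=1e-5))
```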


Profiling *einsum*

In [1]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

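a = torch.randn(512, 512).cuda()
b = torch.randn(512, 512).cuda()

# No warm-up call precedes the profiled region, so the CPU time recorded here
# likely includes one-time CUDA/library initialization, not just the einsum itself.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as profiling:
    with record_function("einsum"):  # label that appears in the Name column
        torch.einsum('ik,kj->ij', a, b)

# Sort rows by total time spent on the GPU
print(profiling.key_averages().table(sort_by="cuda_time_total"))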

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           einsum         7.29%      11.855ms        99.99%     162.586ms     162.586ms       0.000us         0.00%     788.904us     788.904us             1  
                                     aten::einsum         8.71%      14.154ms        92.70%     150.731ms     150.731ms       0.000us         0.00%     788.904us     788.904us             1  
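The CPU total of about 162 ms above is far larger than the CUDA total of under 1 ms, which suggests the first call is paying start-up costs. A hedged variant is sketched below: it runs a few warm-up calls before profiling and falls back to CPU when no GPU is present. The `einsum_warm` label, the warm-up count, and the fallback are assumptions for illustration.

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

a = torch.randn(512, 512, device=device)
b = torch.randn(512, 512, device=device)

# Warm-up: run the op a few times so one-time setup is not attributed to the profiled call
for _ in range(3):
    torch.einsum('ik,kj->ij', a, b)
if device == "cuda":
    torch.cuda.synchronize()  # wait for queued GPU work before profiling starts

with profile(activities=activities) as prof:
    with record_function("einsum_warm"):
        out = torch.einsum('ik,kj->ij', a, b)
    if device == "cuda":
        torch.cuda.synchronize()  # make sure the kernel finishes inside the profiled region

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```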
                                       

Profiling *bmm* and *matmul*

In [5]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

B, M, K, N = 128, 64, 64, 64
A = torch.randn(B, M, K, device='cuda')
B_ = torch.randn(B, K, N, device='cuda')

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    torch.cuda.synchronize()

    with record_function("einsum"):
        torch.einsum('bik,bkj->bij', A, B_)
    torch.cuda.synchronize()

    with record_function("bmm"):
        torch.bmm(A, B_)
    torch.cuda.synchronize()

    with record_function("matmul"):
        torch.matmul(A, B_)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::bmm        12.75%     158.883us        65.15%     811.618us     270.539us     145.338us       100.00%     145.338us      48.446us             3  
     volta_sgemm_64x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us     145.338us       100.00%     145.338us      48.446us             3  
                      bmm         4.58%      57.023us         9.47%     118.011us     118.011us       0.000us         0.00%      48.734us      48.734us        