In [1]:
import numpy as np
import torch

# Seed NumPy's global RNG and PyTorch's default generator so the random
# operands created further down are reproducible across runs.
np.random.seed(0)
# torch.manual_seed returns the Generator object, which is why this cell
# shows a `<torch._C.Generator ...>` repr as its output.
torch.manual_seed(0)

<torch._C.Generator at 0x1a4cab4cc70>

爱因斯坦求和约定（einsum）：用一个下标字符串简洁地表示矩阵乘积、逐元素乘积、求和、转置等运算。

矩阵相乘例子: 'ij,jk->ik'（下标 j 同时出现在两个输入中、但不出现在输出中，因此会沿 j 求和）

https://zhuanlan.zhihu.com/p/434232512

https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py#L205

# 矩阵相乘

## 维度不需要调整

In [2]:
# Integer operands for the 2-D examples: a is 3x4 and b is 4x5, so the
# inner dimension (4) matches for a matrix product.
a = np.arange(3 * 4).reshape((3, 4))
b = np.arange(4 * 5).reshape((4, 5))

In [3]:
# Tensor counterparts of a and b. torch.from_numpy wraps the existing array
# buffer rather than copying it.
a_t, b_t = (torch.from_numpy(arr) for arr in (a, b))

In [4]:
# Baseline matrix product computed without einsum; the last line displays
# the shape of the result.
c0 = np.matmul(a, b)
c0.shape

(3, 5)

In [5]:
# Baseline matrix product on the tensor operands, computed without einsum;
# the last line displays the shape of the result.
c0_t = torch.matmul(a_t, b_t)
c0_t.shape

torch.Size([3, 5])

In [6]:
# Matrix product as an einsum: j appears in both inputs but not in the
# output, so it is summed over; i and k are kept.
c1 = np.einsum("ij,jk->ik", a, b)  # whitespace in the subscripts is optional
# Element-wise comparison against the `@` result, reduced to one boolean.
np.all(c0 == c1)

True

In [7]:
# Matrix product as an einsum: j appears in both inputs but not in the
# output, so it is summed over; i and k are kept.
c1_t = torch.einsum("ij,jk->ik", a_t, b_t)  # whitespace in the subscripts is optional
# Element-wise comparison against the `@` result, reduced to one boolean tensor.
torch.all(c0_t == c1_t)

tensor(True)

## 维度需要调整

In [8]:
# The second operand is passed transposed (b.T), so its axes are labelled
# "k j" instead of "j k". einsum matches the shared j axis by label, so no
# transpose back is needed before the product.
c2 = np.einsum("i j, k j -> i k", a, b.T)
# Same result as the untransposed formulation.
np.all(c1 == c2)

True

In [9]:
# The second operand is passed transposed (b_t.T), so its axes are labelled
# "k j" instead of "j k". einsum matches the shared j axis by label, so no
# transpose back is needed before the product.
c2_t = torch.einsum("i j, k j -> i k", a_t, b_t.T)
# Same result as the untransposed formulation.
torch.all(c1_t == c2_t)

tensor(True)

# 多维矩阵相乘

In [10]:
# Random float operands with a leading axis of size 1 for the batched
# examples. a is drawn before b so the sequence of random values is unchanged.
a = np.random.random(size=(1, 196, 768))
b = np.random.random(size=(1, 768, 196))

In [11]:
# Tensor counterparts of the 3-D operands. torch.from_numpy wraps the
# existing array buffer rather than copying it.
a_t, b_t = (torch.from_numpy(arr) for arr in (a, b))

In [12]:
# Baseline batched matrix product computed without einsum: the leading axis
# is treated as a batch and the last two axes are multiplied.
c0 = np.matmul(a, b)
c0.shape

(1, 196, 196)

In [13]:
# Baseline batched matrix product on the tensor operands, computed without
# einsum: the leading axis is a batch and the last two axes are multiplied.
c0_t = torch.matmul(a_t, b_t)
c0_t.shape

torch.Size([1, 196, 196])

## 不需要转置

In [14]:
# Batched matrix product as an einsum: the label b is shared by both inputs
# and kept in the output (batch axis), c is shared and dropped (summed over),
# p and k are kept.
c1 = np.einsum("b p c, b c k -> b p k", a, b)
c1.shape

(1, 196, 196)

In [37]:
# First value: bit-for-bit equality of the `@` and einsum results.
# Second value: equality within np.allclose's default tolerances.
# With float operands the two results agree only up to rounding (presumably
# a different summation order), so the exact check is False and the tolerant
# check is True.
np.array_equal(c0, c1), np.allclose(c0, c1)

(False, True)

In [16]:
# Batched matrix product as an einsum: the label b is shared by both inputs
# and kept in the output (batch axis), c is shared and dropped (summed over),
# p and k are kept.
c1_t = torch.einsum("b p c, b c k -> b p k", a_t, b_t)  # whitespace in the subscripts is optional
c1_t.shape

torch.Size([1, 196, 196])

In [17]:
# Exact element-wise comparison of the `@` and einsum results, reduced to a
# single boolean tensor. NOTE(review): this is a bit-for-bit check on float
# tensors; the NumPy counterpart of this comparison only matched within
# tolerance, so do not rely on exact equality in general.
(c0_t == c1_t).all()

tensor(True)

## 需要转置

In [18]:
# The second operand is passed with its last two axes swapped, so it is
# labelled "b k c" instead of "b c k". einsum matches the shared c axis by
# label, so no swap back is needed before the product.
c2 = np.einsum("b p c, b k c -> b p k", a, b.swapaxes(1, 2))
c2.shape

(1, 196, 196)

In [38]:
# First value: the `@` result and the einsum result agree within
# np.allclose's default tolerances.
# Second value: the two einsum formulations (plain and axis-swapped operand)
# are compared bit-for-bit.
np.allclose(c0, c1), np.array_equal(c1, c2)

(True, True)

In [20]:
# The second operand is passed with its last two axes swapped, so it is
# labelled "b k c" instead of "b c k". einsum matches the shared c axis by
# label, so no swap back is needed before the product.
c2_t = torch.einsum("b p c, b k c -> b p k", a_t, b_t.transpose(1, 2))
c2_t.shape

torch.Size([1, 196, 196])

In [21]:
# Exact element-wise comparisons, each reduced to one boolean tensor:
# `@` vs. einsum, then einsum vs. einsum with the axis-swapped operand.
(c0_t == c1_t).all(), (c1_t == c2_t).all()

(tensor(True), tensor(True))

# 多矩阵乘法

In [22]:
# Three operands labelled (i, j), (i, j, k) and (j, k). They are drawn in the
# order a, b, c so the sequence of random values is unchanged.
a, b, c = torch.randn(3, 4), torch.randn(3, 4, 5), torch.randn(4, 5)
# j appears in every operand but not in the output, so the three-way product
# is summed over j; i and k survive into the (i, k) result.
d = torch.einsum("i j, i j k, j k -> i k", a, b, c)
d.shape

torch.Size([3, 5])

In [23]:
# Display the values of the (3, 5) einsum result.
d

tensor([[-0.7962,  0.1540, -0.7433, -0.7034, -0.1971],
        [ 0.0826,  0.1641,  0.3449, -0.6521,  9.4170],
        [-1.1619,  0.2474, -0.5765,  1.0773, -0.3230]])

# 求和

In [24]:
# Empty output subscripts: both i and j are dropped, so every element of a
# is summed into a single 0-d tensor.
torch.einsum("i j ->", a)

tensor(-2.6891)

# 列求和（对列下标 j 求和，结果是每一行的和，长度等于行数）

In [25]:
# j is dropped from the output, so each row is summed across its columns;
# the result has one entry per row of a.
torch.einsum("i j -> i", a)

tensor([ 1.6113, -1.7591, -2.5412])

# 行求和（对行下标 i 求和，结果是每一列的和，长度等于列数）

In [26]:
# i is dropped from the output, so each column is summed down its rows;
# the result has one entry per column of a.
torch.einsum("i j -> j", a)

tensor([-0.8182, -0.4623, -2.3534,  0.9448])

# 点乘（逐元素相乘，Hadamard 积；没有下标被求和）

In [27]:
# Every input label also appears in the output, so nothing is summed: this
# is the element-wise product of the two operands, with the same shape as a.
torch.einsum("i j , i j -> i j", a * 0.5, a * 5)

tensor([[1.0935, 0.1781, 0.0095, 0.9651],
        [0.5105, 0.0690, 5.7971, 0.3642],
        [2.6399, 0.7926, 1.9905, 0.0085]])

In [28]:
# Same element-wise product, but the output labels are written "j i", so the
# result comes back transposed.
torch.einsum("i j , i j -> j i", a * 0.5, a * 5)

tensor([[1.0935, 0.5105, 2.6399],
        [0.1781, 0.0690, 0.7926],
        [0.0095, 5.7971, 1.9905],
        [0.9651, 0.3642, 0.0085]])

# 点乘后求和（逐元素相乘后对所有元素求和，即 Frobenius 内积）

In [29]:
# Empty output subscripts: the element-wise product of the two operands is
# summed over both i and j, giving a single 0-d tensor.
torch.einsum("i j, i j ->", a * 0.5, a * 5)

tensor(14.4185)

In [30]:
# Two-step version of the same reduction: einsum produces the (transposed)
# element-wise product and .sum() then adds up every element.
torch.einsum("i j, i j -> j i", a * 0.5, a * 5).sum()

tensor(14.4185)

# 转置

In [31]:
# Show the shape of a before it is transposed.
a.size()

torch.Size([3, 4])

In [32]:
# No label is dropped; only the output order changes from "i j" to "j i",
# which returns the transpose of a.
torch.einsum("i j -> j i", a)

tensor([[ 0.6614, -0.4519, -1.0276],
        [ 0.2669, -0.1661, -0.5631],
        [ 0.0617, -1.5228, -0.8923],
        [ 0.6213,  0.3817, -0.0583]])

# 广播+相乘（外积：两个一维向量得到二维矩阵）

In [33]:
# Rebind a to the 1-D integer tensor [0, 1, 2, 3] and display it.
a = torch.arange(4)
a

tensor([0, 1, 2, 3])

In [34]:
# Rebind b to the 1-D integer tensor [0, 1, 2] and display it.
b = torch.arange(3)
b

tensor([0, 1, 2])

In [35]:
# Outer product: the two inputs carry different labels and both labels are
# kept, so entry (i, j) of the result is a[i] * b[j].
torch.einsum("i, j -> i j", a, b)

# a broadcast across the columns:
# [0, 0, 0]
# [1, 1, 1]
# [2, 2, 2]
# [3, 3, 3]
# multiplied column by column with b:
# [0, 1, 2]

tensor([[0, 0, 0],
        [0, 1, 2],
        [0, 2, 4],
        [0, 3, 6]])