In [1]:
import torch

# mul 逐元素相乘(element-wise / Hadamard 积,不是向量点积)

## 标量 * 矩阵(这里用形状为 [1, 1] 的张量充当标量,通过广播与矩阵逐元素相乘)

In [30]:
# Operands for the broadcasting demo: a (1, 1) tensor standing in for a scalar
# and a (2, 2) matrix. Tuple assignment evaluates left to right, so the two
# random draws happen in the same order as with separate statements.
t1, t2 = torch.randn(1, 1), torch.randn(2, 2)
(t1, t2)

(tensor([[-1.3066]]),
 tensor([[-0.6392, -0.8061],
         [ 1.5133,  0.0743]]))

In [31]:
# Operator form of element-wise multiplication: the (1, 1) tensor t1 is
# broadcast against the (2, 2) tensor t2, giving a (2, 2) result.
t1 * t2

tensor([[ 0.8352,  1.0533],
        [-1.9773, -0.0971]])

In [32]:
# Function form of the same element-wise product; the printed values match
# those of `t1 * t2`.
torch.mul(t1, t2)

tensor([[ 0.8352,  1.0533],
        [-1.9773, -0.0971]])

In [18]:
# Element-wise equality mask between the operator form and the function form;
# every entry is True when both spellings give the same product.
(t1 * t2).eq(torch.mul(t1, t2))

tensor([[True, True],
        [True, True]])

## 矩阵 * 矩阵 按广播规则逐元素相乘:从最后一维向前对齐,每一维的大小必须相等,或者其中一个为 1(缺失的前导维度视为 1)
- [a, b, c, d] * [d]
- [a, b, c, d] * [c, d]
- [a, b, c, d] * [b, c, d]
- [a, b, c, d] * [a, b, c, d]

In [23]:
# (2, 2) * (1, 2) -> (2, 2): the size-1 leading dim of the second operand is
# broadcast; the last dims (2 and 2) match.
(torch.randn(2, 2) * torch.randn(1, 2)).shape

torch.Size([2, 2])

In [25]:
# (2, 2) * (2, 2) -> (2, 2): identical shapes, so no broadcasting is needed.
(torch.randn(2, 2) * torch.randn(2, 2)).shape

torch.Size([2, 2])

In [26]:
# (1, 2, 3, 4) * (4,) -> (1, 2, 3, 4): the 1-D operand lines up with the last
# dim and is broadcast over the three leading dims.
(torch.randn(1, 2, 3, 4) * torch.randn(4)).shape

torch.Size([1, 2, 3, 4])

In [27]:
# (1, 2, 3, 4) * (3, 4) -> (1, 2, 3, 4): the 2-D operand lines up with the last
# two dims and is broadcast over the two leading dims.
(torch.randn(1, 2, 3, 4) * torch.randn(3, 4)).shape

torch.Size([1, 2, 3, 4])

In [28]:
# (1, 2, 3, 4) * (2, 3, 4) -> (1, 2, 3, 4): the 3-D operand lines up with the
# last three dims; only the leading dim of size 1 is missing on its side.
(torch.randn(1, 2, 3, 4) * torch.randn(2, 3, 4)).shape

torch.Size([1, 2, 3, 4])

In [29]:
# (1, 2, 3, 4) * (1, 2, 3, 4) -> (1, 2, 3, 4): identical shapes, plain
# element-wise product.
(torch.randn(1, 2, 3, 4) * torch.randn(1, 2, 3, 4)).shape

torch.Size([1, 2, 3, 4])

# 矩阵相乘

## @ / matmul

In [39]:
# Operands for 2-D matrix multiplication: (2, 3) and (3, 4). The inner
# dimensions (3 and 3) agree, so t1 @ t2 is defined.
t1, t2 = torch.randn(2, 3), torch.randn(3, 4)

In [40]:
# `@` is the matrix-multiplication operator: (2, 3) @ (3, 4) -> (2, 4).
(t1 @ t2).shape

torch.Size([2, 4])

In [41]:
# Function form of `@`; the result shape is the same (2, 4).
torch.matmul(t1, t2).shape

torch.Size([2, 4])

In [43]:
# Reduce the element-wise equality of `@` and torch.matmul to a single 0-dim
# bool tensor: True only if every entry matches.
torch.eq(t1 @ t2, torch.matmul(t1, t2)).all()

tensor(True)

## torch.mm(mat1, mat2) 只适用于二维矩阵相乘,不支持广播;一般推荐使用 `@` / `torch.matmul`

In [49]:
# Fresh 2-D operands for the torch.mm comparison: (2, 3) and (3, 4).
t1, t2 = torch.randn(2, 3), torch.randn(3, 4)

In [50]:
# torch.mm on two 2-D tensors: (2, 3) x (3, 4) -> (2, 4).
torch.mm(t1, t2).shape

torch.Size([2, 4])

In [51]:
# The `@` operator on the same 2-D operands gives the same (2, 4) shape as torch.mm.
(t1 @ t2).shape

torch.Size([2, 4])

In [54]:
# Check that `@`, torch.matmul and torch.mm agree on 2-D inputs.
# Compare with a tolerance (torch.isclose) rather than exact `==`: these are
# floating-point results obtained through different entry points, and bitwise
# equality between them is not something to rely on.
# torch.all reduces each element-wise mask to a 0-dim bool tensor, so the cell
# still displays a pair of `tensor(True)` values.
(
    torch.all(torch.isclose(t1 @ t2, torch.mm(t1, t2))),
    torch.all(torch.isclose(torch.matmul(t1, t2), torch.mm(t1, t2))),
)

(tensor(True), tensor(True))

## torch.bmm(mat1, mat2) 只适用于三维张量相乘,第 0 维是 batch;两个输入的 batch 大小必须相同(不支持广播)

In [60]:
# Batched operands: 2 batches of (3, 2) matrices and 2 batches of (2, 3)
# matrices. The last expression shows both shapes.
t1, t2 = torch.randn(2, 3, 2), torch.randn(2, 2, 3)
(t1.shape, t2.shape)

(torch.Size([2, 3, 2]), torch.Size([2, 2, 3]))

In [61]:
# Batched matrix multiply: for each of the 2 batches, (3, 2) @ (2, 3) -> (3, 3),
# giving (2, 3, 3) overall.
torch.bmm(t1, t2).shape

torch.Size([2, 3, 3])

In [62]:
# `@` on the same 3-D operands yields the same (2, 3, 3) shape as torch.bmm.
(t1 @ t2).shape

torch.Size([2, 3, 3])

In [63]:
# Check that `@`, torch.matmul and torch.bmm agree on 3-D (batched) inputs.
# Compare with a tolerance (torch.isclose) rather than exact `==`: these are
# floating-point results obtained through different entry points, and bitwise
# equality between them is not something to rely on.
# torch.all reduces each element-wise mask to a 0-dim bool tensor, so the cell
# still displays a pair of `tensor(True)` values.
(
    torch.all(torch.isclose(t1 @ t2, torch.bmm(t1, t2))),
    torch.all(torch.isclose(torch.matmul(t1, t2), torch.bmm(t1, t2))),
)

(tensor(True), tensor(True))

## 高维矩阵相乘 `[..., i, j] @ [..., j, k] = [..., i, k]`(最后两维做矩阵乘法,前面的 batch 维 `...` 按广播规则对齐)

In [104]:
# A 4-D left operand (1, 2, 3, 4) and three right operands of decreasing rank
# that all end in (4, 5). The generator is consumed in order, so the random
# draws happen in the same sequence as four separate assignments.
t1, t2, t3, t4 = (
    torch.randn(*dims)
    for dims in [(1, 2, 3, 4), (1, 2, 4, 5), (2, 4, 5), (4, 5)]
)

In [105]:
# (1, 2, 3, 4) @ (1, 2, 4, 5) -> (1, 2, 3, 5): batch dims (1, 2) match and the
# last two dims multiply as (3, 4) @ (4, 5).
(t1 @ t2).shape, torch.matmul(t1, t2).shape

(torch.Size([1, 2, 3, 5]), torch.Size([1, 2, 3, 5]))

In [106]:
# (1, 2, 3, 4) @ (2, 4, 5) -> (1, 2, 3, 5): the right operand's batch dim (2,)
# is broadcast against the left operand's batch dims (1, 2).
(t1 @ t3).shape, torch.matmul(t1, t3).shape

(torch.Size([1, 2, 3, 5]), torch.Size([1, 2, 3, 5]))

In [107]:
# (1, 2, 3, 4) @ (4, 5) -> (1, 2, 3, 5): the 2-D right operand is reused for
# every batch entry of the left operand.
(t1 @ t4).shape, torch.matmul(t1, t4).shape

(torch.Size([1, 2, 3, 5]), torch.Size([1, 2, 3, 5]))

In [108]:
# Mirror case: a 2-D left operand (3, 4) and right operands of increasing rank
# that all end in (4, 5). The generator is consumed in order, so the random
# draws happen in the same sequence as four separate assignments.
t1, t2, t3, t4 = (
    torch.randn(*dims)
    for dims in [(3, 4), (4, 5), (2, 4, 5), (1, 2, 4, 5)]
)

In [109]:
# (3, 4) @ (4, 5) -> (3, 5): plain 2-D matrix product, no batch dims.
(t1 @ t2).shape, torch.matmul(t1, t2).shape

(torch.Size([3, 5]), torch.Size([3, 5]))

In [110]:
# (3, 4) @ (2, 4, 5) -> (2, 3, 5): the 2-D left operand is broadcast over the
# right operand's batch dim of size 2.
(t1 @ t3).shape, torch.matmul(t1, t3).shape

(torch.Size([2, 3, 5]), torch.Size([2, 3, 5]))

In [111]:
# (3, 4) @ (1, 2, 4, 5) -> (1, 2, 3, 5): the 2-D left operand is broadcast over
# the right operand's batch dims (1, 2).
(t1 @ t4).shape, torch.matmul(t1, t4).shape

(torch.Size([1, 2, 3, 5]), torch.Size([1, 2, 3, 5]))