## torch.matmul() vs. torch.bmm() vs. torch.mm() (plus torch.outer())

In [5]:
import torch

# ---- torch.matmul() (@) ---- #
# The most permissive of the three functions: it accepts 1-D (vector) operands,
# batched operands, and broadcasts batch dimensions. The flip side is that a
# shape mistake can still yield a "valid" result, so there are more places
# where a bug can hide.

# Vector * Vector -> Inner product (0-d tensor)
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
result = torch.matmul(tensor1, tensor2)
print(f"Matmul: {tensor1.shape} & {tensor2.shape} -> {result.shape}")

# Matrix * Vector -> matrix-vector product, (3, 4) @ (4,) -> (3,)
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
result = torch.matmul(tensor1, tensor2)
print(f"Matmul: {tensor1.shape} & {tensor2.shape} -> {result.shape}")

# Batched matrix * Broadcasted vector, (10, 3, 4) @ (4,) -> (10, 3)
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
result = torch.matmul(tensor1, tensor2)
print(f"Matmul: {tensor1.shape} & {tensor2.shape} -> {result.shape}")

# Batched matrix * Batched matrix, (10, 3, 4) @ (10, 4, 5) -> (10, 3, 5)
# All 10 batch entries of tensor2 are copies of one (4, 5) matrix, so the
# broadcast case below must reproduce this result.
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5).repeat(10, 1, 1)
result = torch.matmul(tensor1, tensor2)
print(f"Matmul: {tensor1.shape} & {tensor2.shape} -> {result.shape}")

# Batched matrix * Broadcasted matrix, (10, 3, 4) @ (4, 5) -> (10, 3, 5):
# the single (4, 5) matrix is broadcast over the batch dimension of tensor1.
tensor2 = tensor2[0]
_result = torch.matmul(tensor1, tensor2)
print(f"Matmul: {tensor1.shape} & {tensor2.shape} -> {_result.shape}")
# The two calls may go through different kernels (batched vs. a single 2-D
# product) and round differently; an explicit atol keeps near-zero entries
# from spuriously failing the default (atol=1e-8) comparison.
assert torch.allclose(result, _result, atol=1e-6)
print("========================")

# ---- torch.mm() ---- #
# Plain 2-D matrix product: (n, m) @ (m, p) -> (n, p).
# Unlike torch.matmul(), there is no broadcasting and no batch dimension, so
# passing anything other than two 2-D tensors is an error. Because shape
# mistakes surface immediately, prefer torch.mm() over torch.matmul() when the
# operands are known to be 2-D.
tensor1, tensor2 = torch.randn(2, 3), torch.randn(3, 2)
result = torch.mm(tensor1, tensor2)
print(f"mm: {tensor1.shape} & {tensor2.shape} -> {result.shape}")
# For two 2-D operands the `@` operator (torch.matmul) computes the same product.
assert torch.allclose(result, tensor1 @ tensor2)
print("========================")

# ---- torch.bmm() ---- #
# Strict batched matrix product: (b, n, m) @ (b, m, p) -> (b, n, p).
# Both operands must be 3-D with the same batch size b. Unlike torch.matmul(),
# torch.bmm() does not broadcast, so a shape mismatch raises an error instead of
# silently broadcasting. Prefer it over torch.matmul() when the operands are
# known to be batched matrices.
tensor1 = torch.randn(10, 2, 3)
tensor2 = torch.randn(10, 3, 2)
result = torch.bmm(tensor1, tensor2)
print(f"bmm: {tensor1.shape} & {tensor2.shape} -> {result.shape}")
# Same sanity check as the torch.mm() section: for two 3-D operands with equal
# batch sizes, torch.matmul() computes the same batched product.
assert torch.allclose(result, torch.matmul(tensor1, tensor2))
print("========================")

# ---- torch.outer() ---- #
# Outer product of two 1-D tensors: (n,) x (m,) -> (n, m),
# where result[i, j] == tensor1[i] * tensor2[j].
tensor1, tensor2 = torch.arange(1, 4), torch.arange(1, 3)  # shapes (3,) and (2,)
result = torch.outer(tensor1, tensor2)  # shape (3, 2)
print(f"outer: {tensor1.shape} & {tensor2.shape} -> {result.shape}")
print("========================")

Matmul: torch.Size([3]) & torch.Size([3]) -> torch.Size([])
Matmul: torch.Size([3, 4]) & torch.Size([4]) -> torch.Size([3])
Matmul: torch.Size([10, 3, 4]) & torch.Size([4]) -> torch.Size([10, 3])
Matmul: torch.Size([10, 3, 4]) & torch.Size([10, 4, 5]) -> torch.Size([10, 3, 5])
Matmul: torch.Size([10, 3, 4]) & torch.Size([4, 5]) -> torch.Size([10, 3, 5])
mm: torch.Size([2, 3]) & torch.Size([3, 2]) -> torch.Size([2, 2])
bmm: torch.Size([10, 2, 3]) & torch.Size([10, 3, 2]) -> torch.Size([10, 2, 2])
outer: torch.Size([3]) & torch.Size([2]) -> torch.Size([3, 2])


## Tensor.expand() vs. Tensor.repeat()

In [22]:
# ---- `Tensor.repeat` copies the tensor ---- #
# tensor.repeat(*sizes) tiles the data sizes[d] times along dim d and always
# allocates new memory. If more sizes than dims are given, the tensor is
# treated as having extra leading size-1 dims, so the (3,) tensor below
# behaves like a (1, 3) tensor in the first two calls.
tensor1 = torch.tensor([1, 2, 3])

# repeat `dim=0` 4 times, `dim=1` 2 times
# (3,) treated as (1, 3) -> (4, 6)
tiled = tensor1.repeat(4, 2)
print(tiled, tiled.size())

# (3,) treated as (1, 3) -> (2, 3)
stacked = tensor1.repeat(2, 1)
print(stacked, stacked.size())

# One size for a 1-D tensor adds no dimension: (3,) -> (6,), not (1, 6)
flat = tensor1.repeat(2)
print(flat, flat.size())
print("========================")

# ---- `Tensor.expand` ---- #
# tensor.expand(*sizes) returns a view; no data is copied. Only dimensions of
# size 1 can be enlarged (their stride becomes 0, so every position along that
# dim reads the same element), new dims may be added at the front, and -1
# keeps a dimension's current size.
tensor1 = torch.tensor([[1], [2], [3]])
# (3, 1) -> (3, 4)
expanded1 = tensor1.expand(3, 4)
print(expanded1, expanded1.size())

tensor2 = torch.randn(1, 2, 3)
# (1, 2, 3) -> (4, 2, 3)
print(tensor2)
expanded2 = tensor2.expand(4, -1, -1)
print(expanded2, expanded2.size())
print("========================")

# Tensor.repeat vs. Tensor.expand: same values, different memory behavior.
repeated2 = tensor2.repeat(4, 1, 1)
print(torch.allclose(repeated2, expanded2))
# expand shares tensor2's storage (stride 0 along the expanded dim) ...
assert expanded2.data_ptr() == tensor2.data_ptr()
assert expanded2.stride(0) == 0
# ... while repeat allocates a fresh copy.
assert repeated2.data_ptr() != tensor2.data_ptr()

tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]]) torch.Size([4, 6])
tensor([[1, 2, 3],
        [1, 2, 3]]) torch.Size([2, 3])
tensor([1, 2, 3, 1, 2, 3]) torch.Size([6])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]]) torch.Size([3, 4])
tensor([[[ 0.4609, -0.1003, -0.7988],
         [-0.5215, -1.0076,  0.4091]]])
tensor([[[ 0.4609, -0.1003, -0.7988],
         [-0.5215, -1.0076,  0.4091]],

        [[ 0.4609, -0.1003, -0.7988],
         [-0.5215, -1.0076,  0.4091]],

        [[ 0.4609, -0.1003, -0.7988],
         [-0.5215, -1.0076,  0.4091]],

        [[ 0.4609, -0.1003, -0.7988],
         [-0.5215, -1.0076,  0.4091]]]) torch.Size([4, 2, 3])
True
