In [12]:
import torch
import logging
# Configure the root logger: emit INFO and above, each line prefixed with
# the timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

### torch.matmul

`torch.matmul` 用于计算两个 tensor 的矩阵乘积，具体行为取决于两个 tensor 的维度：
  - 如果两个参数都是一维的，则返回它们的点积（标量）
  - 如果两个参数都是二维的，则返回矩阵-矩阵乘积
  - 如果第一个参数是一维的，第二个参数是二维的，则先在第一个参数的维度前添加一个 1（形状由 (n,) 变为 (1, n)）以进行矩阵乘法；相乘之后，添加的这一维会被移除
  - 如果第一个参数是二维的，第二个参数是一维的，则返回矩阵-向量乘积
  - 如果两个参数都至少是一维的，且至少有一个参数的维度大于 2，则返回批量（batched）矩阵乘积，其中批量维度会按广播规则对齐

In [13]:
# vector x vector: two 1-D operands of equal length reduce to their dot
# product, so the result is a 0-dim tensor (empty size).
tensor1, tensor2 = torch.randn(3), torch.randn(3)
torch.matmul(tensor1, tensor2).shape

torch.Size([])

In [20]:
# matrix x vector: a (3, 4) matrix times a length-4 vector yields a
# length-3 vector (one dot product per matrix row).
# Log arguments are passed lazily (%s) instead of via f-strings, so the
# tensors are only rendered to text when the record is actually emitted.
tensor1 = torch.randn(3, 4)
logging.info('tensor1: %s', tensor1)
tensor2 = torch.randn(4)
logging.info('tensor2: %s', tensor2)
result = torch.matmul(tensor1, tensor2)
logging.info('result: %s', result)
result_size = result.size()
logging.info('result size: %s', result_size)

2024-05-13 20:45:23,065 - INFO - tensor1: tensor([[ 0.3501,  0.3488,  0.1369,  1.8027],
        [ 1.6219, -0.2836, -0.3403, -2.1379],
        [-0.7798,  0.2807, -1.4226, -0.6640]])
2024-05-13 20:45:23,073 - INFO - tensor2: tensor([ 0.5485,  0.3208, -0.2034,  0.9330])
2024-05-13 20:45:23,073 - INFO - result: tensor([ 1.9580, -1.1269, -0.6678])
2024-05-13 20:45:23,081 - INFO - result size: torch.Size([3])


In [22]:
# batched matrix x broadcasted vector: the length-4 vector is applied to
# each of the 10 (3, 4) matrices, giving a (10, 3) result.
# Log arguments are passed lazily (%s) instead of via f-strings, so the
# tensors are only rendered to text when the record is actually emitted.
tensor1 = torch.randn(10, 3, 4)
logging.info('tensor1: %s', tensor1)
tensor2 = torch.randn(4)
logging.info('tensor2: %s', tensor2)
result = torch.matmul(tensor1, tensor2)
logging.info('result: %s', result)
result_size = result.size()
logging.info('result size: %s', result_size)

2024-05-13 20:45:51,084 - INFO - tensor1: tensor([[[ 0.2357,  0.5761,  0.8934, -0.0079],
         [ 0.4015, -0.5655, -0.1581,  0.3875],
         [ 1.4332, -0.3094,  1.4427, -0.1985]],

        [[ 0.5763,  0.1448, -0.4362, -1.2500],
         [-1.2595,  0.1285,  0.2561, -0.1456],
         [ 0.6547,  0.7291, -0.8775, -3.1186]],

        [[-0.6534,  1.8665,  0.7400, -0.4123],
         [ 0.8981,  0.2849,  0.0479,  2.2511],
         [ 1.4180, -1.3560, -0.1808, -0.3535]],

        [[-0.4277,  0.7968, -0.8294, -0.2217],
         [-1.3969, -1.0473,  0.0531,  0.4833],
         [ 0.3600,  0.8056, -0.9741, -0.0374]],

        [[ 0.2764, -1.5071,  0.9104,  0.3201],
         [-0.9385,  0.2217, -0.9779,  0.2385],
         [ 0.8215, -1.3298, -0.7635, -0.2740]],

        [[ 0.0291, -0.3835, -0.4896,  0.0337],
         [-1.3517, -0.6486, -1.3467, -1.3427],
         [-0.7678,  0.2538,  0.1539, -1.3407]],

        [[ 0.6212,  0.5787, -0.4068, -1.7679],
         [ 1.3880, -1.4773,  0.6720,  1.7329],
      

In [23]:
# batched matrix x batched matrix: both operands carry a batch dim of 10,
# so each (3, 4) matrix is multiplied with its paired (4, 5) matrix.
tensor1, tensor2 = torch.randn(10, 3, 4), torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).shape

torch.Size([10, 3, 5])

In [24]:
# batched matrix x broadcasted matrix: the single (4, 5) matrix is reused
# for every one of the 10 (3, 4) matrices in the batch.
tensor1, tensor2 = torch.randn(10, 3, 4), torch.randn(4, 5)
torch.matmul(tensor1, tensor2).shape

torch.Size([10, 3, 5])