In [1]:
import torch

In [2]:
_ = torch.tensor([0.2126, 0.7152, 0.0722], names=["c"])

  _ = torch.tensor([0.2126, 0.7152, 0.0722], names=['c'])


In [7]:
img_t = torch.randn(3, 5, 5)

In [4]:
batch_t = torch.randn(2, 3, 5, 5)  # shape [batch, channels, rows, columns]

In [5]:
img_gray_naive = img_t.mean(-3)
img_gray_naive.shape

torch.Size([5, 5])

In [6]:
batch_gray_naive = batch_t.mean(-3)
batch_gray_naive.shape

torch.Size([2, 5, 5])

In [8]:
weights = torch.tensor([0.2126, 0.7152, 0.0722])

In [9]:
unsqueezed_weights = weights.unsqueeze(-1).unsqueeze_(-1)
unsqueezed_weights.shape, unsqueezed_weights

(torch.Size([3, 1, 1]),
 tensor([[[0.2126]],
 
         [[0.7152]],
 
         [[0.0722]]]))

In [10]:
weighted_img = img_t * unsqueezed_weights
weighted_img.shape

torch.Size([3, 5, 5])

In [11]:
weighted_batch = batch_t * unsqueezed_weights
weighted_batch.shape

torch.Size([2, 3, 5, 5])

In [12]:
img_gray_weighted = weighted_img.sum(-3)
batch_gray_weighted = weighted_batch.sum(-3)

In [13]:
img_gray_weighted.shape, batch_gray_weighted.shape

(torch.Size([5, 5]), torch.Size([2, 5, 5]))

In [14]:
img_t.shape, weights.shape

(torch.Size([3, 5, 5]), torch.Size([3]))

"爱因斯坦求和约定"，爱因斯坦在研究相对论时，曾经对冗余的求和公式做了一个简化版的约定。AI中的具体实现一般指np和torch中的einsum()。
不推荐使用，看起来较为直观，但更容易出错，且不易debug。

einsum()的四句口诀：
- 外部重复做乘积 ,

- 内部重复把数取 ,

- 从有到无要求和 ,

- 重复默认要丢弃.

In [15]:
img_gray_weighted_fancy = torch.einsum("...chw,c->...hw", img_t, weights)
img_gray_weighted_fancy.shape

torch.Size([5, 5])

In [16]:
batch_gray_weighted_fancy = torch.einsum("...chw,c->...hw", batch_t, weights)
batch_gray_weighted_fancy.shape

torch.Size([2, 5, 5])

In [18]:
weights_named = torch.tensor([0.2126, 0.7152, 0.0722], names=["channels"])
weights_named

tensor([0.2126, 0.7152, 0.0722], names=('channels',))

In [19]:
img_named = img_t.refine_names(..., "channels", "rows", "colums")
batch_named = batch_t.refine_names(..., "channels", "rows", "columns")

In [20]:
img_named

tensor([[[-0.8458, -1.5019,  2.1762, -1.8889,  0.3466],
         [ 1.1400, -0.3311,  0.0593, -0.5156, -0.2877],
         [ 0.0842,  1.4210, -0.6301,  2.1071,  0.2736],
         [-0.0924,  2.3741, -0.1490,  0.9867, -1.4424],
         [-0.3587,  0.3222,  0.2832, -1.3613, -0.8284]],

        [[-3.5773, -0.0695, -1.0679, -0.3833,  0.3665],
         [-0.1612,  0.7765,  1.4100,  2.0078, -1.9607],
         [-0.8928,  2.1080, -0.3435,  0.4030,  1.5907],
         [ 0.4806,  1.6486, -0.1582,  0.4909, -0.8611],
         [-0.8707, -0.5649,  0.4726,  0.7306, -3.1533]],

        [[-1.1492, -1.0514, -1.1450, -0.0886, -0.1075],
         [-0.9897, -0.7020,  1.3134, -1.4369,  0.1279],
         [-0.4654,  0.8840,  0.9332, -0.1327, -0.7569],
         [ 1.1022, -1.2850,  0.7903, -0.2787, -0.7794],
         [-0.6142,  0.4623,  2.1319,  0.1869,  0.7083]]],
       names=('channels', 'rows', 'colums'))

In [21]:
batch_named

tensor([[[[-0.5502,  1.0179, -0.9398,  1.2022,  1.4469],
          [-0.4110, -2.1574, -0.0588,  0.9639,  1.0047],
          [ 0.0604,  1.7600, -1.2831,  0.3098,  1.2640],
          [-1.1063,  0.0469,  1.0534,  0.5009, -0.5312],
          [ 0.0398, -0.9085,  0.0977,  0.4397,  0.6624]],

         [[ 0.3348, -0.2372, -1.0977,  0.3140, -0.5675],
          [ 0.7674, -0.7300,  0.5761,  0.6203,  0.2813],
          [ 0.4190,  0.3505,  0.0930, -0.3457, -1.4747],
          [-0.6163, -0.7263,  1.8851, -0.3143, -0.0369],
          [-1.2528, -0.9152,  0.8228,  0.4309, -0.7692]],

         [[-0.8637,  2.5586, -1.8656, -1.0413, -0.8514],
          [-1.5720,  0.4588,  1.2960,  0.7910,  1.6022],
          [ 0.8303, -0.6765, -0.4364, -1.5555,  0.5091],
          [ 1.1311, -0.6482, -1.4978,  0.5350, -0.8970],
          [-0.3692, -0.4460,  0.4988, -0.9703,  1.6142]]],


        [[[ 0.7030,  0.1718, -1.0258, -0.6740, -0.2337],
          [-0.2686,  0.5959, -0.5900, -0.1290,  0.3990],
          [-0.3384, -0.

In [22]:
temp = img_t
temp.shape

torch.Size([3, 5, 5])

In [23]:
temp[None]
temp.shape

torch.Size([3, 5, 5])

In [24]:
temp

tensor([[[-0.8458, -1.5019,  2.1762, -1.8889,  0.3466],
         [ 1.1400, -0.3311,  0.0593, -0.5156, -0.2877],
         [ 0.0842,  1.4210, -0.6301,  2.1071,  0.2736],
         [-0.0924,  2.3741, -0.1490,  0.9867, -1.4424],
         [-0.3587,  0.3222,  0.2832, -1.3613, -0.8284]],

        [[-3.5773, -0.0695, -1.0679, -0.3833,  0.3665],
         [-0.1612,  0.7765,  1.4100,  2.0078, -1.9607],
         [-0.8928,  2.1080, -0.3435,  0.4030,  1.5907],
         [ 0.4806,  1.6486, -0.1582,  0.4909, -0.8611],
         [-0.8707, -0.5649,  0.4726,  0.7306, -3.1533]],

        [[-1.1492, -1.0514, -1.1450, -0.0886, -0.1075],
         [-0.9897, -0.7020,  1.3134, -1.4369,  0.1279],
         [-0.4654,  0.8840,  0.9332, -0.1327, -0.7569],
         [ 1.1022, -1.2850,  0.7903, -0.2787, -0.7794],
         [-0.6142,  0.4623,  2.1319,  0.1869,  0.7083]]])

In [25]:
temp[0, 0, 0] = 1
temp, img_t

(tensor([[[ 1.0000, -1.5019,  2.1762, -1.8889,  0.3466],
          [ 1.1400, -0.3311,  0.0593, -0.5156, -0.2877],
          [ 0.0842,  1.4210, -0.6301,  2.1071,  0.2736],
          [-0.0924,  2.3741, -0.1490,  0.9867, -1.4424],
          [-0.3587,  0.3222,  0.2832, -1.3613, -0.8284]],
 
         [[-3.5773, -0.0695, -1.0679, -0.3833,  0.3665],
          [-0.1612,  0.7765,  1.4100,  2.0078, -1.9607],
          [-0.8928,  2.1080, -0.3435,  0.4030,  1.5907],
          [ 0.4806,  1.6486, -0.1582,  0.4909, -0.8611],
          [-0.8707, -0.5649,  0.4726,  0.7306, -3.1533]],
 
         [[-1.1492, -1.0514, -1.1450, -0.0886, -0.1075],
          [-0.9897, -0.7020,  1.3134, -1.4369,  0.1279],
          [-0.4654,  0.8840,  0.9332, -0.1327, -0.7569],
          [ 1.1022, -1.2850,  0.7903, -0.2787, -0.7794],
          [-0.6142,  0.4623,  2.1319,  0.1869,  0.7083]]]),
 tensor([[[ 1.0000, -1.5019,  2.1762, -1.8889,  0.3466],
          [ 1.1400, -0.3311,  0.0593, -0.5156, -0.2877],
          [ 0.0842,  1

None作为占位符，可以用来扩展tensor的维度，直接使用`tensor = tensor[None]`将在高维增加一维，使用`tensor = tensor[..., None]`将在低维增加一维

In [34]:
points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points = points[None]
points.shape

torch.Size([1, 3, 2])

In [35]:
points = points[..., None]
points.shape

torch.Size([1, 3, 2, 1])

In [39]:
weights_aligned = weights_named.align_as(img_named)
weights_aligned.shape

torch.Size([3, 1, 1])

In [46]:
img_t.shape

torch.Size([3, 5, 5])

In [50]:
test = torch.ones_like(img_t)
test_weights = torch.tensor([1, 2, 3], names=["channels"])
test, test_weights

(tensor([[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]],
 
         [[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]],
 
         [[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]),
 tensor([1, 2, 3], names=('channels',)))

In [52]:
test = test.refine_names(..., "channels", "rows", "columns")
test_weights = test_weights.align_as(test)
test.shape, test_weights.shape

(torch.Size([3, 5, 5]), torch.Size([3, 1, 1]))

In [53]:
test * test_weights

tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]],

        [[3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3.],
         [3., 3., 3., 3., 3.]]], names=('channels', 'rows', 'columns'))

当形状为[3,4,5]的张量与形状为[3,1,1]的张量相乘时，计算逻辑如下：

1. 首先，将形状为[3,1,1]的张量进行广播，使其形状与[3,4,5]的张量相匹配。广播的规则是在缺失的维度上将维度大小为1的张量进行复制，直到两个张量的形状相同。

2. 然后，对应位置的元素进行相乘。由于广播后两个张量的形状相同，所以可以直接对应位置的元素相乘。

3. 最后，得到的结果是一个形状为[3,4,5]的张量，其中每个元素是原始张量对应位置元素的乘积。

需要注意的是，广播操作只会复制元素的值，并不会增加张量的内存占用。因此，计算逻辑是在不增加内存占用的情况下对应位置的元素相乘。

In [37]:
img_named * weights_aligned

tensor([[[ 0.2126, -0.3193,  0.4627, -0.4016,  0.0737],
         [ 0.2424, -0.0704,  0.0126, -0.1096, -0.0612],
         [ 0.0179,  0.3021, -0.1340,  0.4480,  0.0582],
         [-0.0197,  0.5047, -0.0317,  0.2098, -0.3067],
         [-0.0763,  0.0685,  0.0602, -0.2894, -0.1761]],

        [[-2.5585, -0.0497, -0.7638, -0.2741,  0.2621],
         [-0.1153,  0.5554,  1.0084,  1.4360, -1.4023],
         [-0.6385,  1.5077, -0.2457,  0.2882,  1.1377],
         [ 0.3438,  1.1791, -0.1131,  0.3511, -0.6158],
         [-0.6227, -0.4040,  0.3380,  0.5225, -2.2552]],

        [[-0.0830, -0.0759, -0.0827, -0.0064, -0.0078],
         [-0.0715, -0.0507,  0.0948, -0.1037,  0.0092],
         [-0.0336,  0.0638,  0.0674, -0.0096, -0.0546],
         [ 0.0796, -0.0928,  0.0571, -0.0201, -0.0563],
         [-0.0443,  0.0334,  0.1539,  0.0135,  0.0511]]],
       names=('channels', 'rows', 'colums'))

这里的img_named[...,:3]是为了在与weights_named相乘时让后者可以广播，也就是让二者满足相乘的条件。
但是可以看到，二者依旧无法相乘，因为二者的最低维命名不匹配。

In [59]:
try:
    gray_named = (img_named[..., :3] * weights_named).sum("channels")
except Exception as e:
    print(e)

Error when attempting to broadcast dims ['channels', 'rows', 'colums'] and dims ['channels']: dim 'colums' and dim 'channels' are at the same position from the right but do not match.


torch中的*等价于torch.mul()，本质上是element-wise的乘法，也就是逐元素进行乘法，尺寸不统一的先根据广播机制扩展至相同尺寸，再进行对应元素的乘法。

In [57]:
a = torch.tensor([2, 3, 4, 5])
b = torch.tensor([2, 3, 4, 5])
a * b

tensor([ 4,  9, 16, 25])

In [63]:
gray_named = (img_named * weights_aligned).sum("channels")
gray_named.shape, gray_named.names

(torch.Size([5, 5]), ('rows', 'colums'))

In [64]:
gray_plain = gray_named.rename(None)
gray_plain.shape, gray_plain.names

(torch.Size([5, 5]), (None, None))