## torch.gather
Gathers values along an axis specified by dim.
沿着某一维度取特定的值
input 和 index 的维数（ndim）必须相同；对所有 d != dim 还要求 index.size(d) <= input.size(d)。输出的形状与 index 相同。

In [1]:
import torch
from pprint import pprint


In [3]:
# Base 3x3 float32 tensor holding 0..8 row by row; reused by the gather/scatter/split/stack/cat examples.
t = torch.arange(0, 9, dtype=torch.float32).reshape(3, 3)

In [4]:
# Display the 3x3 float tensor built above (values 0..8).
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

**out[i][j][k] = input[index[i][j][k]][j][k]**  # if dim == 0 (3-D)
**out[i][j] = input[index[i][j]][j]**  # if dim == 0 (2-D)
**out[i][j] = input[i][index[i][j]]**  # if dim == 1 (2-D)

In [5]:
# Both gathers use the same index tensor, so both results have its shape (3, 2).
#   dim=0: out[i][j] = t[index[i][j]][j]  -> index picks the ROW,    e.g. out[0][0] = t[1][0]
#   dim=1: out[i][j] = t[i][index[i][j]]  -> index picks the COLUMN, e.g. out[0][0] = t[0][1]
gather_t_dim_0, gather_t_dim_1 = (
    t.gather(axis, torch.tensor([[1, 2], [2, 1], [0, 1]])) for axis in (0, 1)
)

In [6]:
# Print each gather result next to its label.
for label, value in [("gather_t_dim_0:", gather_t_dim_0), ("gather_t_dim_1:", gather_t_dim_1)]:
    print(label, value)

gather_t_dim_0: tensor([[3., 7.],
        [6., 4.],
        [0., 4.]])
gather_t_dim_1: tensor([[1., 2.],
        [5., 4.],
        [6., 7.]])


## torch.scatter

Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. For each value in `src`, its output index is specified by its index in `src` for dimension != `dim` and by the corresponding value in `index` for dimension = `dim`.

用法与 ***torch.gather*** 类似，多了一个 ***src***，会把 ***src*** 里的元素按 ***index*** 写到 ***input*** 的对应位置。注意 `torch.scatter` 不会修改 ***input*** 本身，而是返回一个新张量；原地修改的版本是 `Tensor.scatter_`。

**self[index[i][j][k]][j][k] = src[i][j][k]**  # if dim == 0 (3-D)
**self[index[i][j]][j] = src[i][j]**  # if dim == 0 (2-D)

注意：如果 index 让多个 src 元素写到同一个位置，PyTorch 文档说明结果是不确定的（non-deterministic，会任意保留其中一个值）。

In [7]:
# Re-display t; its values are the `src` of the scatter example below.
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

In [8]:
# torch.scatter is out-of-place: it returns a copy of `input` in which, for dim=0,
#     out[index[i][j]][j] = src[i][j]
# i.e. index[i][j] chooses the ROW that src[i][j] is written to (the column stays j).
# index/src are 3x3, so only the first 3 columns of the (3, 5) zero tensor can be written;
# columns 3 and 4 stay zero.
# WARNING: this index sends several src values to the same cell (e.g. out[0][0] receives
# both src[1][0] = 3 and src[2][0] = 6). PyTorch documents the result for non-unique
# indices as non-deterministic; the displayed output happens to keep the last write.
scatter_t_dim_0 = torch.scatter(
    input=torch.zeros((3, 5)),
    dim=0,
    index=torch.tensor([[1, 2, 1], [0, 2, 1], [0, 1, 2]]),
    src=t,
)

In [9]:
# Scatter result: shape (3, 5); columns 3-4 are never indexed, so they remain zero.
scatter_t_dim_0

tensor([[6., 0., 0., 0., 0.],
        [0., 7., 5., 0., 0.],
        [0., 4., 8., 0., 0.]])

## torch.split

split_size_or_sections：可以是一个 int 或一个列表。如果是 int，就按该大小依次切块（最后一块可能更小）；如果是列表，就按照列表里的数依次划分，列表之和必须等于该维度的长度。

In [10]:
# Re-display t before splitting it along dim 0.
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

In [12]:
# List form of split_size_or_sections: chunks of 1 row and 2 rows along dim 0 (1 + 2 == t.size(0)).
torch.split(t, [1, 2], dim=0)

(tensor([[0., 1., 2.]]),
 tensor([[3., 4., 5.],
         [6., 7., 8.]]))

## torch.stack  VS  torch.cat

#### torch.stack
- input 为 sequence of tensor（可以是 list of tensor）。**要求所有张量的形状（shape）必须完全相同！！**
- stack会在指定的dim新增一个维度，如果指定dim=0，那么则在第0维新增

#### torch.cat
- cat 相比 stack 不会新增维度，而是在指定的 dim 上拼接；如果指定 dim=0，则在第 0 维拼接。除拼接的那一维外，其余各维的大小必须相同。

In [58]:
# Second 3x3 float32 tensor, same layout as t but holding 1..9, for the stack/cat comparison.
t2 = torch.arange(start=1, end=10, dtype=torch.float32).reshape(3, 3)

In [59]:
# Shapes first, then values, one item per line.
print(t.shape, t2.shape, t, t2, sep="\n")

torch.Size([3, 3])
torch.Size([3, 3])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])


In [60]:
# torch.stack inserts a NEW axis of size 2 at `dim`:
#   dim=0 -> (2, 3, 3): the two tensors one after the other
#   dim=1 -> (3, 2, 3): row i pairs t[i] with t2[i]
t_stack_dim_0, t_stack_dim_1 = (torch.stack((t, t2), dim=axis) for axis in (0, 1))

In [61]:
# Shapes first, then the stacked values, one item per line.
print(t_stack_dim_0.shape, t_stack_dim_1.shape, t_stack_dim_0, t_stack_dim_1, sep="\n")

torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]],

        [[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
tensor([[[0., 1., 2.],
         [1., 2., 3.]],

        [[3., 4., 5.],
         [4., 5., 6.]],

        [[6., 7., 8.],
         [7., 8., 9.]]])


In [62]:
# torch.cat adds no axis; it grows an existing one:
#   dim=0 -> (6, 3): rows of t followed by rows of t2
#   dim=1 -> (3, 6): columns of t followed by columns of t2
t_cat_dim_0, t_cat_dim_1 = (torch.cat((t, t2), dim=axis) for axis in (0, 1))

In [64]:
# Shapes first, then the concatenated values, one item per line.
print(t_cat_dim_0.shape, t_cat_dim_1.shape, t_cat_dim_0, t_cat_dim_1, sep="\n")

torch.Size([6, 3])
torch.Size([3, 6])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.],
        [1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
tensor([[0., 1., 2., 1., 2., 3.],
        [3., 4., 5., 4., 5., 6.],
        [6., 7., 8., 7., 8., 9.]])


In [9]:
# Random batched operands for the torch.matmul vs torch.mm comparison below.
# Seed first so "Restart Kernel -> Run All" reproduces the same values.
# NOTE: this rebinds `t`, replacing the 3x3 tensor used in the sections above.
torch.manual_seed(0)
t = torch.randn(3, 2, 5, 7)
t1 = torch.randn(3, 2, 7, 5)

In [11]:
# `@` dispatches to torch.matmul, which treats the leading dims as batch dims:
# (3, 2, 5, 7) @ (3, 2, 7, 5) -> (3, 2, 5, 5).
# NOTE: this rebinds `t2`, replacing the 3x3 tensor from the stack/cat section.
t2 = t @ t1

In [12]:
# Shape of the batched matmul result.
t2.shape

torch.Size([3, 2, 5, 5])

In [13]:
# Unlike torch.matmul above, torch.mm accepts only 2-D matrices and does not broadcast,
# so calling it on the 4-D tensors t and t1 raises a RuntimeError. Catch it and print
# the message so the demonstration does not abort "Restart Kernel -> Run All".
try:
    t3 = torch.mm(t, t1)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: self must be a matrix