## torch.gather
Gathers values along an axis specified by dim.
沿着某一维度取特定的值
input 和 index 的维数(ndim)要相同，且对除 dim 以外的每个维度 d 都要求 index.size(d) <= input.size(d)；输出的形状与 index 相同

In [39]:
import torch
from pprint import pprint


In [41]:
# Checkpoint path. The string has no `{}` placeholder, so calling str.format on it would be a
# no-op; add one (e.g. './outputs/ent_model_{}.pth'.format(epoch)) if a per-epoch path is intended.
'./outputs/ent_model.pth'

'./outputs/ent_model.pth'

In [2]:
# Base 3x3 float32 tensor holding 0..8 in row-major order, used by the cells below.
t = torch.arange(9, dtype=torch.float32).view(3, 3)

In [3]:
# Display the 3x3 base tensor built above.
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

**out[i][j][k] = input[index[i][j][k]][j][k]**  # if dim == 0 (3-D input)
**out[i][j] = input[index[i][j]][j]**  # if dim == 0 (2-D input)
**out[i][j] = input[i][index[i][j]]**  # if dim == 1 (2-D input)

In [4]:
# dim=0: each index entry selects a ROW while the column stays j,
#   out[i][j] = t[index[i][j]][j]   e.g. index[0][0]=1 -> t[1][0], index[0][1]=2 -> t[2][1]
gather_t_dim_0 = t.gather(0, torch.tensor([[1, 2], [2, 1], [0, 1]]))
# dim=1: each index entry selects a COLUMN while the row stays i,
#   out[i][j] = t[i][index[i][j]]
gather_t_dim_1 = t.gather(1, torch.tensor([[1, 2], [2, 1], [0, 1]]))

In [5]:
# Print both gather results, each prefixed with its variable name.
for label, result in (("gather_t_dim_0:", gather_t_dim_0), ("gather_t_dim_1:", gather_t_dim_1)):
    print(label, result)

gather_t_dim_0: tensor([[3., 7.],
        [6., 4.],
        [0., 4.]])
gather_t_dim_1: tensor([[1., 2.],
        [5., 4.],
        [6., 7.]])


## Torch.scatter

Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. For each value in `src`, its output index is specified by its index in `src` for `dimension != dim` and by the corresponding value in `index` for `dimension = dim`.

用法与 ***torch.gather*** 类似, 多了一个 ***src*** , 会把 ***src*** 里的元素放到 ***input*** 里

**self[index[i][j][k]][j][k] = src[i][j][k]**  # if dim == 0 (3-D)
**self[index[i][j]][j] = src[i][j]**  # if dim == 0 (2-D)

注意: 如果 index 把多个 src 元素写到同一个位置，PyTorch 不保证最终保留哪一个值(结果是不确定的)。

In [6]:
# Display the 3x3 base tensor used as `src` in the scatter example.
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

In [7]:
# Scatter along dim=0 into a fresh 3x5 zero tensor: for every (i, j) write
#   out[index[i][j]][j] = t[i][j]
# i.e. index picks the destination ROW, the column stays j.
# NOTE: every column of this index repeats a row (column 0 -> rows 1, 0, 0; column 1 -> 2, 2, 1;
# column 2 -> 1, 1, 2), so several src values target the same cell. PyTorch documents the result
# for duplicate indices as nondeterministic; the output shown below happens to keep the last write.
scatter_index = None  # placeholder removed below; keeps no extra global
del scatter_index
scatter_t_dim_0 = torch.scatter(
    input=torch.zeros((3, 5)),
    dim=0,
    index=torch.tensor([[1, 2, 1], [0, 2, 1], [0, 1, 2]]),
    src=t,
)

In [8]:
# Display the scatter result.
scatter_t_dim_0

tensor([[6., 0., 0., 0., 0.],
        [0., 7., 5., 0., 0.],
        [0., 4., 8., 0., 0.]])

## torch.split

split_size_or_sections: 可以是一个整数或一个列表。整数表示每一块的大小(最后一块可能更小)；如果是列表，就按照列表里的数依次划分，列表元素之和必须等于该维度的长度。

In [9]:
# Display the 3x3 base tensor that is split below.
t

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

In [10]:
# Split the rows into chunks of 1 and 2 rows (the sizes must add up to t.size(0) == 3).
torch.split(t, [1, 2], dim=0)

(tensor([[0., 1., 2.]]),
 tensor([[3., 4., 5.],
         [6., 7., 8.]]))

In [36]:
# Seed the global RNG so the torch.rand tensors below (and later random cells) are
# reproducible under Restart & Run All.
torch.manual_seed(0)
a = torch.rand((1, 150, 64))
b = torch.rand((1, 150, 64))
c = torch.arange(4).reshape((1, 1, 1, 4))
d = torch.arange(27).reshape((3, 3, 3))  # three 3x3 slices, used by torch.triu below

In [21]:
# Keep only entries strictly above the main diagonal (diagonal=1). triu acts on the last two
# dims, so each 3x3 slice d[k] is processed independently.
d.triu(diagonal=1)

tensor([[[ 0,  1,  2],
         [ 0,  0,  5],
         [ 0,  0,  0]],

        [[ 0, 10, 11],
         [ 0,  0, 14],
         [ 0,  0,  0]],

        [[ 0, 19, 20],
         [ 0,  0, 23],
         [ 0,  0,  0]]])

In [25]:
# torch.nonzero returns one row of coordinates per nonzero element -> shape (num_nonzero, ndim).
# Built with the torch.tensor factory instead of the legacy torch.Tensor constructor.
nonzero_coords = torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
                                             [0.0, 0.0, 0.0, 0.0],
                                             [0.0, 0.0, 1.2, 0.0],
                                             [0.0, 0.0, 0.0, -0.4]]))
nonzero_coords


tensor([[0, 0],
        [2, 2],
        [3, 3]])

In [20]:
# Shape of `a`, created above with torch.rand((1, 150, 64)).
a.shape

torch.Size([1, 150, 64])

In [20]:
# Outer product via einsum: out[b, n, k] = lhs[b, n] * rhs[k] -> shape (3, 4, 5).
# The operands are built here instead of reusing `c` / `d`: above they are defined with shapes
# (1, 1, 1, 4) and (3, 3, 3), which do not match the 'bn,d' subscripts, so einsum would raise.
# These operands reproduce the output shown below.
einsum_outer = torch.einsum('bn,d->bnd', torch.arange(12).reshape(3, 4), torch.arange(5))
einsum_outer

tensor([[[ 0,  0,  0,  0,  0],
         [ 0,  1,  2,  3,  4],
         [ 0,  2,  4,  6,  8],
         [ 0,  3,  6,  9, 12]],

        [[ 0,  4,  8, 12, 16],
         [ 0,  5, 10, 15, 20],
         [ 0,  6, 12, 18, 24],
         [ 0,  7, 14, 21, 28]],

        [[ 0,  8, 16, 24, 32],
         [ 0,  9, 18, 27, 36],
         [ 0, 10, 20, 30, 40],
         [ 0, 11, 22, 33, 44]]])

## torch.stack  VS  torch.cat

#### torch.stack
- input为sequence of tensor(可以是list of tensor) **要求所有张量的形状(shape)必须完全一样！！**
- stack会在指定的dim新增一个维度，如果指定dim=0，那么则在第0维新增

#### torch.cat
- cat相比stack不会新增维度，而是会在指定的dim拼接，如果指定dim=0，那么则在第0维拼接；除拼接的dim以外，其余各维度的大小必须相同

In [58]:
# Second 3x3 float32 operand holding 1..9, i.e. the base tensor shifted by one.
t2 = (torch.arange(9, dtype=torch.float32) + 1).reshape(3, 3)

In [59]:
# Print both operands' shapes first, then their values.
for item in (t.shape, t2.shape, t, t2):
    print(item)

torch.Size([3, 3])
torch.Size([3, 3])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])


In [60]:
# stack inserts a NEW axis of length 2 at `dim`: dim=0 -> (2, 3, 3), dim=1 -> (3, 2, 3).
t_stack_dim_0, t_stack_dim_1 = (torch.stack((t, t2), dim=axis) for axis in (0, 1))

In [61]:
# Shapes first, then the stacked tensors themselves.
for item in (t_stack_dim_0.shape, t_stack_dim_1.shape, t_stack_dim_0, t_stack_dim_1):
    print(item)

torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]],

        [[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
tensor([[[0., 1., 2.],
         [1., 2., 3.]],

        [[3., 4., 5.],
         [4., 5., 6.]],

        [[6., 7., 8.],
         [7., 8., 9.]]])


In [62]:
# cat joins along an EXISTING axis: dim=0 appends rows -> (6, 3); dim=1 appends columns -> (3, 6).
t_cat_dim_0, t_cat_dim_1 = (torch.cat((t, t2), dim=axis) for axis in (0, 1))

In [64]:
# Shapes first, then the concatenated tensors themselves.
for item in (t_cat_dim_0.shape, t_cat_dim_1.shape, t_cat_dim_0, t_cat_dim_1):
    print(item)

torch.Size([6, 3])
torch.Size([3, 6])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.],
        [1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
tensor([[0., 1., 2., 1., 2., 3.],
        [3., 4., 5., 4., 5., 6.],
        [6., 7., 8., 7., 8., 9.]])


In [9]:
# Two batched operands for matmul: batch dims (3, 2), matrices 5x7 and 7x5.
# NOTE: this rebinds `t`, replacing the 3x3 tensor used by the earlier cells.
t, t1 = torch.randn(3, 2, 5, 7), torch.randn(3, 2, 7, 5)

In [11]:
# Batched matrix product via the @ operator (equivalent to torch.matmul); rebinds `t2`.
t2 = t @ t1

In [12]:
# Batch dims (3, 2) are kept; (5, 7) @ (7, 5) gives (5, 5) per batch element.
t2.shape

torch.Size([3, 2, 5, 5])

In [6]:
a = torch.rand(2, 3, 6)
b = torch.rand(3, 4, 3)
# 'k' appears twice in b's subscripts ("kjk"), so einsum first takes b's diagonal over its
# axes 0 and 2 (b[k, j, k]) and then sums k against a:
#   c[i, j, o] = sum_k a[i, k, o] * b[k, j, k]   -> shape (2, 4, 6)
c = torch.einsum("iko,kjk->ijo", a, b)

In [7]:
# (i, j, o) = (2, 4, 6): i and o come from a, j comes from b.
c.shape

torch.Size([2, 4, 6])

In [19]:
t = torch.arange(12, dtype=torch.float32).view(3, 4)
t1 = torch.arange(5, dtype=torch.float32)
# Element-wise outer product: a[b, d, k] = t[b, d] * t1[k] -> shape (3, 4, 5).
a = torch.einsum('bd,k->bdk', t, t1)

In [20]:
# Append a trailing singleton axis: (3, 4) -> (3, 4, 1), so a matmul with a (1, 5) row
# broadcasts into a per-element outer product. The guard keeps the cell idempotent:
# re-running it does not keep adding axes.
if t.dim() == 2:
    t = t.unsqueeze(-1)

In [21]:
# Expect (3, 4, 1) after unsqueeze(-1).
t.shape

torch.Size([3, 4, 1])

In [22]:
# Turn the length-5 vector into a (1, 5) row matrix for the matmul below. The guard keeps the
# cell idempotent: re-running it does not keep adding leading axes.
if t1.dim() == 1:
    t1 = t1.unsqueeze(0)

In [23]:
# (3, 4, 1) @ (1, 5) broadcasts to (3, 4, 5): the same outer product as the einsum above.
b = t @ t1

In [24]:
# Check that the einsum result `a` and the broadcasted matmul result `b` agree element-wise
# (within torch.allclose's default tolerances).
torch.allclose(a, b)

True