# 一、PyTorch环境检查

In [91]:
import torch

print(torch.__version__)
print("cuda:", torch.cuda.is_available())

1.8.0+cu101
cuda: True


# 二、查看张量类型

In [92]:
import torch

a = torch.randn(2, 3)
b = torch.randint(0, 1, (2, 3))
print(a.type())
print(b.type())
print(type(a))
print(type(b))
print(isinstance(a, torch.FloatTensor))
print(isinstance(b, torch.FloatTensor))

torch.FloatTensor
torch.LongTensor
<class 'torch.Tensor'>
<class 'torch.Tensor'>
True
False


# 三、查看张量尺寸

In [93]:
import torch

a = torch.randn(2, 3)
print(a.size(), type(a.size()))
print(a.shape, type(a.shape))
print("维度数:", a.dim())
print("所占内存大小:", a.numel())

torch.Size([2, 3]) <class 'torch.Size'>
torch.Size([2, 3]) <class 'torch.Size'>
维度数: 2
所占内存大小: 6


# 四、创建张量

## 4.1 生成值全为1的张量

In [94]:
import torch

a = torch.ones(2, 3)
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.]])


## 4.2 生成值全为0的张量

In [95]:
import torch

a = torch.zeros(2, 3)
print(a)

tensor([[0., 0., 0.],
        [0., 0., 0.]])


## 4.3 生成值全为指定值的张量

In [96]:
import torch

a = torch.full([2, 3], 6.6)
print(a)
print(a.shape)

a = torch.full([], 6.6)
print(a)
print(a.shape)

tensor([[6.6000, 6.6000, 6.6000],
        [6.6000, 6.6000, 6.6000]])
torch.Size([2, 3])
tensor(6.6000)
torch.Size([])


## 4.4 通过list创建张量

In [97]:
import torch

print(torch.LongTensor([[1, 2], [3, 4]]))
print(torch.Tensor([[1, 2], [3, 4]]))
print(torch.FloatTensor([[1, 2], [3, 4]]))

tensor([[1, 2],
        [3, 4]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


## 4.5 通过 ndarray 创建张量

In [98]:
import torch
import numpy as np

a = np.array([2, 3.3])
print(type(a))

print(torch.from_numpy(a))

<class 'numpy.ndarray'>
tensor([2.0000, 3.3000], dtype=torch.float64)


## 4.6 创建指定范围和间距的有序张量

In [99]:
import torch

print("torch.arange(0,10):", torch.arange(0, 10))
print("torch.arange(0,10,2):", torch.arange(0, 10, 2))
print("torch.linspace(0,10,steps = 4):", torch.linspace(0, 10, steps=4))
print("torch.linspace(0,10,steps = 10):", torch.linspace(0, 10, steps=10))
print("torch.linspace(0,10,steps = 11):", torch.linspace(0, 10, steps=11))
print("torch.logspace(0,-1,steps = 10):", torch.logspace(0, -1, steps=10))
print("torch.logspace(0,1,steps = 10):", torch.logspace(0, 1, steps=10))

torch.arange(0,10): tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.arange(0,10,2): tensor([0, 2, 4, 6, 8])
torch.linspace(0,10,steps = 4): tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
torch.linspace(0,10,steps = 10): tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,
         8.8889, 10.0000])
torch.linspace(0,10,steps = 11): tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
torch.logspace(0,-1,steps = 10): tensor([1.0000, 0.7743, 0.5995, 0.4642, 0.3594, 0.2783, 0.2154, 0.1668, 0.1292,
        0.1000])
torch.logspace(0,1,steps = 10): tensor([ 1.0000,  1.2915,  1.6681,  2.1544,  2.7826,  3.5938,  4.6416,  5.9948,
         7.7426, 10.0000])


## 4.7 创建单位矩阵（对角线为1）

In [100]:
import torch

# n * n
print(torch.eye(3))
print(torch.eye(4, 4))

# 非 n * n
print(torch.eye(2, 3))

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
tensor([[1., 0., 0.],
        [0., 1., 0.]])


# 五、生成随机张量

## 5.1 按均匀分布生成

In [101]:
import torch

# 生成shape为(2,3,2)的Tensor
random_tensor = torch.rand(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

tensor([[[0.4632, 0.2396],
         [0.0917, 0.9208],
         [0.0559, 0.2765]],

        [[0.7071, 0.8135],
         [0.4985, 0.3682],
         [0.2049, 0.1780]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])


## 5.2 按标准正态分布生成

In [102]:
import torch

# 生成shape为(2,3,2)的Tensor
random_tensor = torch.randn(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

tensor([[[-0.6808, -0.0704],
         [ 1.4618, -1.2485],
         [ 0.4092,  0.7897]],

        [[ 1.2597, -0.1547],
         [ 0.1840,  0.3599],
         [ 0.4327,  0.1597]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])


## 5.3 生成指定区间的整型随机张量

In [103]:
import torch

# 生成shape为(2,3,2)的Tensor
# 整数范围[1,4)
random_tensor = torch.randint(1, 4, (2, 3, 2))
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

tensor([[[3, 1],
         [1, 3],
         [3, 1]],

        [[3, 2],
         [3, 3],
         [1, 1]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])


## 5.4 获取随机序列

In [104]:
import torch

# torch中没有random.shuffle
# y = torch.randperm(n) y是把0到n-1这些数随机打乱得到的一个数字序列
# randperm(n, out=None, dtype=torch.int64)-> LongTensor
idx = torch.randperm(3)
a = torch.Tensor(4, 2)
print(a)
print(idx, idx.type())
print(a[idx])

tensor([[0.0000e+00, 4.5916e-41],
        [1.6411e-07, 8.2732e+20],
        [1.6689e-07, 1.3167e-08],
        [1.6918e-04, 2.1630e+23]])
tensor([1, 0, 2]) torch.LongTensor
tensor([[1.6411e-07, 8.2732e+20],
        [0.0000e+00, 4.5916e-41],
        [1.6689e-07, 1.3167e-08]])


# 六、张量的索引与切片

## 6.1 索引

In [105]:
import torch

a = torch.rand(4, 3, 28, 28)
print("a[0].shape:", a[0].shape)
print("a[0,0].shape:", a[0, 0].shape)
print("a[0,0,2,4]:", a[0, 0, 2, 4])

a[0].shape: torch.Size([3, 28, 28])
a[0,0].shape: torch.Size([28, 28])
a[0,0,2,4]: tensor(0.4791)


## 6.2 切片

### 6.2.1 获取张量的前/后N个元素

In [106]:
import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a[:2].shape:", a[:2].shape)
print("a[:2,:1,:,:].shape:", a[:2, :1, :, :].shape)
print("a[:2,1:,:,:].shape:", a[:2, 1:, :, :].shape)
print("a[:2,-1:,:,:].shape:", a[:2, -1:, :, :].shape)

a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape: torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])


### 6.2.2 根据指定步长获取张量的前/后N个元素

In [107]:
import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a[:,:,0:28:2,0:28:2].shape:", a[:, :, 0:28:2, 0:28:2].shape)
print("a[:,:,::2,::2].shape:", a[:, :, ::2, ::2].shape)

a.shape: torch.Size([4, 3, 28, 28])
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])


### 6.2.3 根据特殊索引获取张量值

In [108]:
import torch

a = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)

print("a.index_select(0,torch.tensor([0,2])).shape:", a.index_select(0, torch.tensor([0, 2])).shape)
print("a.index_select(1,torch.tensor([1,2])).shape:", a.index_select(1, torch.tensor([1, 2])).shape)
print("a.index_select(2,torch.arange(28)).shape:", a.index_select(2, torch.arange(28)).shape)
print("a.index_select(2,torch.arange(8)).shape:", a.index_select(2, torch.arange(8)).shape)

print("a[...].shape:", a[...].shape)
print("a[0,...].shape:", a[0, ...].shape)
print("a[:,1,...].shape:", a[:, 1, ...].shape)
print("a[...,:2].shape:", a[..., :2].shape)

a.shape: torch.Size([4, 3, 28, 28])
a.index_select(0,torch.tensor([0,2])).shape: torch.Size([2, 3, 28, 28])
a.index_select(1,torch.tensor([1,2])).shape: torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape: torch.Size([4, 3, 28, 28])
a.index_select(2,torch.arange(8)).shape: torch.Size([4, 3, 8, 28])
a[...].shape: torch.Size([4, 3, 28, 28])
a[0,...].shape: torch.Size([3, 28, 28])
a[:,1,...].shape: torch.Size([4, 28, 28])
a[...,:2].shape: torch.Size([4, 3, 28, 2])


### 6.2.4 根据 mask 选取张量值

In [109]:
import torch

a = torch.rand(3, 4)
print(a)

mask = a.ge(0.5)
print(mask)

b = torch.masked_select(a, mask)
print(b)
print(b.shape)

tensor([[0.4041, 0.6005, 0.1984, 0.1672],
        [0.1053, 0.8833, 0.1248, 0.9476],
        [0.6271, 0.6862, 0.7329, 0.0712]])
tensor([[False,  True, False, False],
        [False,  True, False,  True],
        [ True,  True,  True, False]])
tensor([0.6005, 0.8833, 0.9476, 0.6271, 0.6862, 0.7329])
torch.Size([6])


### 6.2.5 根据展平的索引获取张量值

In [110]:
import torch

a = torch.Tensor([[4, 3, 5], [6, 7, 8]])
print(a)
print(torch.take(a, torch.tensor([0, 2, -1])))

tensor([[4., 3., 5.],
        [6., 7., 8.]])
tensor([4., 5., 8.])


# 七、张量的维度变换

## 7.1 view 和 reshape

In [111]:
import torch

a = torch.rand(4, 1, 28, 28)
print(a.shape)

print(a.view(4, 28 * 28).shape)
print(a.view(4 * 28, 28).shape)
print(a.view(4, 28, 28, 1).shape)

torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([112, 28])
torch.Size([4, 28, 28, 1])


## 7.2 unsqueeze 升维

In [112]:
import torch

a = torch.rand(4, 1, 28, 28)
print("a.shape:", a.shape)

print("a.unsqueeze(0).shape:", a.unsqueeze(0).shape)
print("a.unsqueeze(-1).shape:", a.unsqueeze(-1).shape)

b = torch.rand(32)
print("b.shape:", b.shape)
print("b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape:", b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)

a.shape: torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape: torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape: torch.Size([4, 1, 28, 28, 1])
b.shape: torch.Size([32])
b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape: torch.Size([1, 32, 1, 1])


## 7.3 squeeze 降维

In [113]:
import torch

b = torch.rand(4, 1, 28, 28)
print("b.shape:", b.shape)

print("b.squeeze().shape:", b.squeeze().shape)
print("b.squeeze(0).shape:", b.squeeze(0).shape)
print("b.squeeze(-1).shape:", b.squeeze(-1).shape)

b.shape: torch.Size([4, 1, 28, 28])
b.squeeze().shape: torch.Size([4, 28, 28])
b.squeeze(0).shape: torch.Size([4, 1, 28, 28])
b.squeeze(-1).shape: torch.Size([4, 1, 28, 28])


## 7.4 expand

In [114]:
import torch

b = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)

print("b.expand(4,32,14,14).shape:", b.expand(4, 32, 14, 14).shape)
print("b.expand(-1,32,-1,-1).shape:", b.expand(-1, 32, -1, -1).shape)
print("b.expand(-1,32,-1,4).shape:", b.expand(-1, 32, -1, 4).shape)

b.shape: torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape: torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape: torch.Size([1, 32, 1, 1])
b.expand(-1,32,-1,4).shape: torch.Size([1, 32, 1, 4])


## 7.5 repeat

In [115]:
import torch

b = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)

print("b.repeat(4,32,1,1).shape:", b.repeat(4, 32, 1, 1).shape)
print("b.repeat(4,1,1,1).shape:", b.repeat(4, 1, 1, 1).shape)
print("b.repeat(4,1,32,32).shape:", b.repeat(4, 1, 32, 32).shape)

b.shape: torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape: torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape: torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape: torch.Size([4, 32, 32, 32])


## 7.6 .t()转置

In [116]:
import torch

b = torch.rand(3, 4)
print(b)
print(b.t())

tensor([[0.0218, 0.0356, 0.9472, 0.4952],
        [0.6545, 0.3335, 0.9319, 0.7809],
        [0.7616, 0.7047, 0.6538, 0.6556]])
tensor([[0.0218, 0.6545, 0.7616],
        [0.0356, 0.3335, 0.7047],
        [0.9472, 0.9319, 0.6538],
        [0.4952, 0.7809, 0.6556]])


## 7.7 transpose 维度变换

In [117]:
import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.transpose(1, 3).shape)

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])


## 7.8 permute 维度变换

In [118]:
import torch

a = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.permute(0, 2, 3, 1).shape)

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])


# 八、张量的拼接和拆分

## 8.1 cat

In [119]:
import torch

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0)
print(c.shape)

torch.Size([9, 32, 8])


## 8.2 stack

In [120]:
import torch

a1 = torch.rand(4, 3, 16, 32)
a2 = torch.rand(4, 3, 16, 32)
c = torch.stack([a1, a2], dim=2)
print(c.shape)

torch.Size([4, 3, 2, 16, 32])


## 8.3 split

In [121]:
import torch

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])

aa, bb = c.split([1, 1], dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

aa, bb = c.split([20, 12], dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 20, 8]) torch.Size([2, 12, 8])

torch.Size([2, 32, 8])
torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
torch.Size([2, 20, 8]) torch.Size([2, 12, 8])


## 8.4 chunk

In [122]:
import torch

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])

aa, bb = c.chunk(2, dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

aa, bb = c.chunk(2, dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 16, 8]) torch.Size([2, 16, 8])

aa, bb, cc, dd = c.chunk(4, dim=1)
print(aa.shape, bb.shape, cc.shape,
      dd.shape)  #torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])

torch.Size([2, 32, 8])
torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
torch.Size([2, 16, 8]) torch.Size([2, 16, 8])
torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])


# 九、基本运算

## 9.1 广播机制

In [123]:
import torch

a = torch.rand(2, 2)
print(a)
b = torch.rand(2)
print(b)
print(a + b)

tensor([[0.5992, 0.2061],
        [0.9040, 0.6210]])
tensor([0.8494, 0.9617])
tensor([[1.4486, 1.1678],
        [1.7535, 1.5827]])


## 9.2 matmul 矩阵/张量乘法

In [127]:
import torch

a = torch.ones(2, 2) * 3
b = torch.ones(2, 2)
print(a)
print(b)
print(torch.matmul(a, b))

tensor([[3., 3.],
        [3., 3.]])
tensor([[1., 1.],
        [1., 1.]])
tensor([[6., 6.],
        [6., 6.]])


## 9.3 次方运算

In [128]:
import torch

a = torch.ones(2, 2) * 3
print(a)
print(torch.pow(a, 3))

tensor([[3., 3.],
        [3., 3.]])
tensor([[27., 27.],
        [27., 27.]])


## 9.4 sqrt 平方根运算

In [131]:
import torch

a = torch.ones(2, 2) * 9
print(a)
print(torch.pow(a, 0.5))
print(torch.sqrt(a))

tensor([[9., 9.],
        [9., 9.]])
tensor([[3., 3.],
        [3., 3.]])
tensor([[3., 3.],
        [3., 3.]])


## 9.5 exp 指数幂运算

In [133]:
import torch

a = torch.ones(2, 2)
print(a)
print(torch.exp(a))

tensor([[1., 1.],
        [1., 1.]])
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])


## 9.6 log 对数运算

In [135]:
import torch

a = torch.ones(2, 2) * 3
print(a)
print(torch.log(a))

tensor([[3., 3.],
        [3., 3.]])
tensor([[1.0986, 1.0986],
        [1.0986, 1.0986]])


## 9.7 取整

In [136]:
import torch

a = torch.tensor(3.14)
print(a)  # tensor(3.1400)
print(torch.floor(a))  #tensor(3.)
print(torch.ceil(a))  #tensor(4.)
print(torch.round(a))  #tensor(3.)
print(torch.trunc(a))  #tensor(3.)
print(torch.frac(a))  #tensor(0.1400)

tensor(3.1400)
tensor(3.)
tensor(4.)
tensor(3.)
tensor(3.)
tensor(0.1400)


## 9.8 clamp 控制张量的取值范围

In [139]:
import torch

a = torch.rand(2, 3) * 15

print(a)
# 将大于8的值设置为8；小于4的值设置为4
print(torch.clamp(a, 4, 8))

tensor([[ 8.8872,  5.6534, 14.3027],
        [ 0.8305, 12.6266, 13.9683]])
tensor([[8.0000, 5.6534, 8.0000],
        [4.0000, 8.0000, 8.0000]])


# 十、统计属性

## 10.1 norm 求范数

In [140]:
import torch

a = torch.ones(2, 3)
b = torch.norm(a)  # 默认求2范数
c = torch.norm(a, p=1)  # 指定求1范数
print(a)
print(b)
print(c)

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor(2.4495)
tensor(6.)


## 10.2 mean、median、sum、min、max、prod、argmax、argmin

In [146]:
import torch

a = torch.arange(8).view(2, 4).float()
print(a)
'''
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
'''
print(a.mean())  #tensor(3.5000)
print(a.median())  #tensor(3.)
print(a.sum())  #tensor(28.)
print(a.min())  #tensor(0.)
print(a.max())  #tensor(7.)
print(a.prod())  #tensor(0.)
print(a.argmax())  #tensor(7)
print(a.argmin())  #tensor(0)

tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
tensor(3.5000)
tensor(3.)
tensor(28.)
tensor(0.)
tensor(7.)
tensor(0.)
tensor(7)
tensor(0)


In [152]:
import torch

a = torch.rand(2, 4)
print(a)

print(a.max(dim=1))
print(a.max(dim=1, keepdim=True))

tensor([[0.7239, 0.9412, 0.7602, 0.2131],
        [0.6277, 0.1033, 0.8300, 0.9909]])
torch.return_types.max(
values=tensor([0.9412, 0.9909]),
indices=tensor([1, 3]))
torch.return_types.max(
values=tensor([[0.9412],
        [0.9909]]),
indices=tensor([[1],
        [3]]))


## 10.3 topk

In [157]:
import torch

a = torch.rand(2, 4)
print(a)
print(a.topk(2, dim=1))
'''
tensor([[0.3247, 0.9220, 0.4314, 0.8123],
        [0.7133, 0.2471, 0.0281, 0.3595]])
torch.return_types.topk(
values=tensor([[0.9220, 0.8123],
        [0.7133, 0.3595]]),
indices=tensor([[1, 3],
        [0, 3]]))
'''

tensor([[0.3247, 0.9220, 0.4314, 0.8123],
        [0.7133, 0.2471, 0.0281, 0.3595]])
torch.return_types.topk(
values=tensor([[0.9220, 0.8123],
        [0.7133, 0.3595]]),
indices=tensor([[1, 3],
        [0, 3]]))


## 10.4 kthvalue

In [167]:
import torch

a = torch.rand(2, 4)
print(a)
print(a.kthvalue(3, dim=1))
'''
tensor([[0.0980, 0.0479, 0.9298, 0.5638],
        [0.9095, 0.9071, 0.4913, 0.6144]])
torch.return_types.kthvalue(
values=tensor([0.5638, 0.9071]),
indices=tensor([3, 1]))
'''

tensor([[0.0980, 0.0479, 0.9298, 0.5638],
        [0.9095, 0.9071, 0.4913, 0.6144]])
torch.return_types.kthvalue(
values=tensor([0.5638, 0.9071]),
indices=tensor([3, 1]))


## 10.5 比较运算函数

In [171]:
import torch

a = torch.rand(2, 3)
print(a)
'''
tensor([[0.1196, 0.5068, 0.9272],
        [0.6395, 0.2433, 0.9702]])
'''
# a >= 0.5
print(a.ge(0.5))
'''
tensor([[False,  True,  True],
        [ True, False,  True]])
'''
# a > 0.5
print(a.gt(0.5))
'''
tensor([[False,  True,  True],
        [ True, False,  True]])
'''
# a <= 0.5
print(a.le(0.5))
'''
tensor([[ True, False, False],
        [False,  True, False]])
'''
# a < 0.5
print(a.lt(0.5))
'''
tensor([[ True, False, False],
        [False,  True, False]])
'''
# a = 0.5
print(a.eq(0.5))
'''
tensor([[False, False, False],
        [False, False, False]])
'''

tensor([[0.1196, 0.5068, 0.9272],
        [0.6395, 0.2433, 0.9702]])
tensor([[False,  True,  True],
        [ True, False,  True]])
tensor([[False,  True,  True],
        [ True, False,  True]])
tensor([[ True, False, False],
        [False,  True, False]])
tensor([[ True, False, False],
        [False,  True, False]])
tensor([[False, False, False],
        [False, False, False]])


# 十一、高级操作

## 11.1 where

In [172]:
import torch

cond = torch.rand(2, 2)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(cond)
'''
tensor([[0.3622, 0.9658],
        [0.1774, 0.6670]])
'''
print(a)
'''
tensor([[0., 0.],
        [0., 0.]])
'''
print(b)
'''
tensor([[1., 1.],
        [1., 1.]])
'''
# 满足条件cond.ge(0.5)的按照a的对应元素赋值，否则按照b的对应元素赋值
print(torch.where(cond.ge(0.5), a, b))
'''
tensor([[1., 0.],
        [1., 0.]])
'''

tensor([[0.3622, 0.9658],
        [0.1774, 0.6670]])
tensor([[0., 0.],
        [0., 0.]])
tensor([[1., 1.],
        [1., 1.]])
tensor([[1., 0.],
        [1., 0.]])


## 11.2 gather

In [184]:
import torch

a = torch.arange(3, 12).view(3, 3)
print(a)
'''
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
'''

index = torch.tensor([[2, 1, 0]])
print(a.gather(1, index)) # tensor([[5, 4, 3]])

index = torch.tensor([[2, 1, 0]]).t()
print(a.gather(1, index))
'''
tensor([[5],
        [7],
        [9]])
'''

index = torch.tensor([[0, 2],
                      [1, 2]])
print(a.gather(1, index))
'''
tensor([[3, 5],
        [7, 8]])
'''

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
tensor([[5, 4, 3]])
tensor([[5],
        [7],
        [9]])
tensor([[3, 5],
        [7, 8]])
