**PyTorch实用技能**

@ Follow: "动手学深度学习-第二章 预备知识"

In [49]:
import numpy as np
import torch

---
# 基本数据结构-Tensor
- Tensor与NumPy的ndarray非常相似, 但多了两个重要功能:
- 支持在GPU上加速计算; 支持自动微分

## 创建Tensor

In [46]:
# 1-D tensor holding the integers 0, 1, ..., 11.
x = torch.arange(12)
# 2x3 matrix whose elements are all 0.
X_all0 = torch.zeros((2, 3))
# 2x3 matrix whose elements are all 1.
X_all1 = torch.ones((2, 3))
# 3-D tensor of shape (3, 2, 1) -- NOT a 2x3 matrix -- whose elements are
# sampled from the standard normal distribution N(0, 1).
X_rand = torch.randn(3, 2, 1)

print(x)
print(X_all0)
print(X_all1)
print(X_rand)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[[ 0.3834],
         [-1.2739]],

        [[ 0.7454],
         [-0.5730]],

        [[-1.6398],
         [-0.2111]]])


In [52]:
ls = [[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]]

# Build a tensor by copying a nested Python list.
X = torch.tensor(ls)

# Build a tensor from a NumPy array; torch.from_numpy shares the array's
# memory rather than copying it.
arr = np.array(ls)
X_arr = torch.from_numpy(arr)

print(X, X_arr, sep="\n")

tensor([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]])
tensor([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]])


## 查看Tensor的属性

In [None]:
# x.shape: number of elements along each axis; x.numel(): total element count.
print(x.shape, x.numel(), sep="\n")

torch.Size([12])
12


## Tensor的操作与运算

### 更改张量形状

In [11]:
# Reshape x into a matrix with 3 rows and 4 columns.
X = x.reshape((3, 4))
# Fix the column count at 3; -1 tells PyTorch to infer the row count.
X1 = x.reshape((-1, 3))

print(X, X1, sep="\n")

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])


### 按元素运算
- 按元素位置单独操作或运算

In [54]:
x = torch.tensor([1.0, 2.0, 4.0, 8.0])
y = torch.tensor([2, 2, 2, 2])
# Element-wise add, subtract, multiply, divide, and power -- the functional
# spellings of the +, -, *, /, ** operators.
torch.add(x, y), torch.sub(x, y), torch.mul(x, y), torch.div(x, y), torch.pow(x, y)

(tensor([ 3.,  4.,  6., 10.]),
 tensor([-1.,  0.,  2.,  6.]),
 tensor([ 2.,  4.,  8., 16.]),
 tensor([0.5000, 1.0000, 2.0000, 4.0000]),
 tensor([ 1.,  4., 16., 64.]))

In [55]:
x.exp()  # element-wise exponential e**x

tensor([2.7183e+00, 7.3891e+00, 5.4598e+01, 2.9810e+03])

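广播机制: 即使两个张量的形状不同, 只要形状兼容, 也可以执行按元素操作。兼容的条件是: 从最后一个维度开始逐一比较, 每对维度要么相等, 要么其中一个为1。长度为1的维度会被复制扩展, 以匹配另一个张量。例如形状(3, 1)与(1, 2)相加, 得到形状(3, 2)的结果。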

In [60]:
# a becomes a (3, 1) column and b a (1, 2) row: their shapes differ.
a = torch.arange(3).unsqueeze(1)
b = torch.arange(2).unsqueeze(0)
a, b

(tensor([[0],
         [1],
         [2]]),
 tensor([[0, 1]]))

In [61]:
torch.add(a, b)  # broadcasts (3, 1) with (1, 2) into a (3, 2) result

tensor([[0, 1],
        [1, 2],
        [2, 3]])

### 张量连接

In [102]:
# Two float 3x4 matrices to concatenate below.
X = torch.arange(0, 12, dtype=torch.float32).view(3, 4)
Y = torch.tensor([[2.0, 1, 4, 3],
                  [1, 2, 3, 4],
                  [4, 3, 2, 1]])
print(X, Y, sep="\n")

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
tensor([[2., 1., 4., 3.],
        [1., 2., 3., 4.],
        [4., 3., 2., 1.]])


In [58]:
torch.cat([X, Y], 0)  # axis 0: Y's rows are appended below X's rows

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 2.,  1.,  4.,  3.],
        [ 1.,  2.,  3.,  4.],
        [ 4.,  3.,  2.,  1.]])

In [59]:
torch.cat([X, Y], 1)  # axis 1: Y's columns are appended right of X's columns

tensor([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
        [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
        [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]])

## 节省内存

In [84]:
"""
直接赋值相同的变量名仍然分配了不同的内存
"""
X = torch.arange(12, dtype=torch.float32).reshape((3, 4))
before = id(X)
print(f"id(X): {before}")

X = X + 1
print(f"id(X): {id(X)}")  # id(X)与before不同

id(X): 4921618464
id(X): 4921198304


In [85]:
"""
使用切片表示法覆盖内存
"""
X = torch.arange(12, dtype=torch.float32).reshape((3, 4))
before = id(X)
print(f"id(X): {before}")

X[:] = X + 1
print(f"id(X): {id(X)}")

id(X): 4921571392
id(X): 4921571392


In [86]:
"""
使用 += 覆盖内存
"""
X = torch.arange(12, dtype=torch.float32).reshape((3, 4))
before = id(X)
print(f"id(X): {before}")

X += 1
print(f"id(X): {id(X)}")

id(X): 4921214128
id(X): 4921214128


## Tensor与NumPy数组的相互转换

In [101]:
A = torch.arange(12, dtype=torch.float32).reshape((3, 4))
B = A.numpy()
print(type(A), type(B))
# A and B are distinct Python objects, so their ids differ -- but .numpy()
# on a CPU tensor does not copy: both objects view the same data buffer.
print(id(A), id(B))
print(A)
print(B)

# Because the data buffer is shared, an in-place write through either one
# (e.g. `B[:] = ...`, `B += 1`, `A[:] = ...`) is visible through the other.
# Rebinding instead (`B = B + 1`) creates a new array and breaks the link.
B[:] = B + 1
print(A)
print(B)

<class 'torch.Tensor'> <class 'numpy.ndarray'>
4921968912 4914318672
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]]


---
# 数据预处理