1. 准备数据集

In [1]:
import torch
import d2l.torch as d2l
# Ground-truth weight vector and bias used to generate the data.
true_w = torch.tensor([2.0, -3.0])
true_b = torch.tensor([4.2])
num_samples = 1000
# Generate `num_samples` (x, y) pairs with the d2l helper.
x_data, y_data = d2l.synthetic_data(true_w, true_b, num_samples)
# Last expression of the cell: display the shapes of features and labels.
x_data.shape, y_data.shape

(torch.Size([1000, 2]), torch.Size([1000, 1]))

2. 用类设计模型

In [8]:
class LinearModel(torch.nn.Module):
    """Linear regression model made of a single fully connected layer.

    Every layer and every network in PyTorch is a subclass of
    ``torch.nn.Module``. A subclass has to provide ``__init__`` and
    ``forward``; no backward method is written by hand because the
    gradients are derived automatically from the forward computation.

    Args:
        in_features: Number of input features per sample. Defaults to 2.
        out_features: Number of outputs per sample. Defaults to 1.
    """

    def __init__(self, in_features: int = 2, out_features: int = 1) -> None:
        # Run the parent initializer first so the submodule assigned below is registered.
        super().__init__()
        # torch.nn.Linear is itself a Module (the predefined fully connected
        # layer); the object built here computes w*x + b.
        self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Return the prediction of the linear layer for the input batch ``x``."""
        # self.linear is a callable object: nn.Module defines __call__, which
        # dispatches to the layer's own forward method.
        return self.linear(x)

model = LinearModel()  # instantiate the model; the instance is itself callable

# model(X) invokes model.__call__(); nn.Module routes that call to forward(),
# so calling the instance performs the forward pass, in effect the same as
# model.forward(X).
# (Kept as comments rather than a bare string literal: a string expression at
# the end of a cell is evaluated and echoed as the cell's output.)

3. 构建损失loss和优化器optimizer

In [9]:
# Loss function: squared error summed over the batch instead of averaged.
# reduction='sum' is the current spelling of the deprecated size_average=False
# and produces the same value without the deprecation warning.
criterion = torch.nn.MSELoss(reduction='sum')
# SGD optimizer: it receives the model's parameters, so it knows which tensors
# to update, together with the learning rate to use.
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)


4. 训练

In [10]:
# Training: full-batch gradient descent over the whole data set.
num_epochs = 1000
for epoch in range(num_epochs):
    # Reset gradients first; backward() accumulates into .grad, so stale
    # values from the previous iteration would otherwise be added in.
    optimizer.zero_grad()

    # Forward pass: predictions, then the loss against the targets.
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(f"{epoch} {loss.item()}")  # .item() yields the loss as a Python number

    # Backward pass followed by one parameter update.
    loss.backward()
    optimizer.step()

# Report the learned weight and bias.
print('w= ', model.linear.weight)
print('b= ', model.linear.bias.item())

# Try the trained model on one hand-written sample.
# torch.tensor(...) is the recommended factory; torch.Tensor(...) is the
# legacy constructor.
x_test = torch.tensor([[4.0, 3.0]])
# Inference does not need gradients: no_grad() skips building the autograd
# graph, so the result can be printed directly without reaching for .data.
with torch.no_grad():
    y_test = model(x_test)
print('y_pred= ', y_test)


0 33079.05859375
1 20819.166015625
2 13109.1484375
3 8258.185546875
4 5204.65673828125
5 3281.67041015625
6 2070.10302734375
7 1306.417236328125
8 824.8302612304688
9 521.0039672851562
10 329.24163818359375
11 208.15777587890625
12 131.6702117919922
13 83.33383178710938
14 52.775177001953125
15 33.44806671142578
16 21.219770431518555
17 13.479883193969727
18 8.579144477844238
19 5.474952220916748
20 3.507995128631592
21 2.2612085342407227
22 1.4706196784973145
23 0.9691581726074219
24 0.6509830951690674
25 0.4490317106246948
26 0.3208182454109192
27 0.2393970936536789
28 0.18767374753952026
29 0.154804527759552
30 0.1339121311903
31 0.12062662094831467
32 0.11217693984508514
33 0.10680210590362549
34 0.10338414460420609
35 0.10120721161365509
36 0.09982220828533173
37 0.09893986582756042
38 0.0983777791261673
39 0.0980205312371254
40 0.09779240190982819
41 0.09764758497476578
42 0.09755461663007736
43 0.09749579429626465
44 0.09745824337005615
45 0.09743440896272659
46 0.09741916507482