# 单隐藏层前向神经网络例子

In [2]:
import torch

In [16]:
# 样本数，样本维度，隐藏层维度，输出层维度
# Batch size, input dim, hidden dim, output dim.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random targets: the network just learns to fit this noise.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Two bias-free weight matrices, tracked by autograd.
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU (expressed as clamp at 0) -> linear.
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Sum (not mean) of squared errors over the whole batch.
    loss = ((y_pred - y) ** 2).sum()
    if not t % 20:
        print(t, loss.item())

    # Fills w1.grad and w2.grad.
    loss.backward()

    # Manual gradient-descent step; no_grad() keeps these in-place
    # updates out of the autograd graph.
    with torch.no_grad():
        for w in (w1, w2):
            w -= learning_rate * w.grad
            # .backward() ADDS into .grad rather than overwriting it, so the
            # gradient must be reset or it would leak into the next step.
            w.grad.zero_()
    

0 29354564.0
20 272370.1875
40 38350.5390625
60 8370.1767578125
80 2262.49853515625
100 694.3679809570312
120 230.3604736328125
140 80.29074096679688
160 28.973861694335938
180 10.727495193481445
200 4.051109313964844
220 1.5543962717056274
240 0.6047776341438293


260 0.23779287934303284
280 0.09440027177333832
300 0.037795063108205795
320 0.01533600315451622
340 0.006372262258082628
360 0.002778454218059778
380 0.0013103475794196129
400 0.0006788184982724488
420 0.0003836988762486726
440 0.00023478207003790885
460 0.00015404449368361384
480 0.00010569652658887208


# 使用nn包构建前向网络

In [17]:
import torch.nn as nn


In [21]:
# 样本数，样本维度，隐藏层维度，输出层维度
# Batch size, input dim, hidden dim, output dim.
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Linear -> ReLU -> Linear, built from torch.nn modules, which own their
# parameters (nn.Linear includes a bias term by default).
layers = [
    nn.Linear(D_in, H),
    nn.ReLU(),
    nn.Linear(H, D_out),
]
model = nn.Sequential(*layers)

# Summed (not averaged) squared error over the batch.
loss_fn = nn.MSELoss(reduction="sum")

learning_rate = 1e-4
for t in range(500):
    # Drop gradients left over from the previous backward pass.
    model.zero_grad()

    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if not t % 20:
        print(t, loss.item())

    loss.backward()

    # Plain SGD step applied in place to every parameter, outside autograd.
    with torch.no_grad():
        for p in model.parameters():
            p -= learning_rate * p.grad


0 644.4891967773438
20 176.6575927734375
40 48.838409423828125
60 13.813834190368652
80 4.492364883422852
100 1.6898926496505737
120 0.7162458300590515
140 0.3351461589336395
160 0.1701829731464386
180 0.09293076395988464
200 0.05348130315542221
220 0.03208024799823761
240 0.0198835376650095
260 0.01264018565416336
280 0.008199227973818779
300 0.005406938958913088
320 0.003610269632190466
340 0.0024333088658750057
360 0.001651531900279224
380 0.0011269977549090981
400 0.0007722655427642167
420 0.0005309298285283148
440 0.0003659167850855738
460 0.000252686848398298
480 0.00017476141510996968
