In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.  
与之前手动更新模型权值的做法不同，这里我们使用 optim 包定义一个优化器（Optimizer），由它来替我们更新权值参数。
这个 optim 包定义了许多在深度学习中经常用到的优化算法，如 SGD+Momentum，RMSProp，Adam 等等。



In [1]:
import torch
from torch.autograd import Variable

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Fix the RNG seed so that Restart & Run All reproduces the same random data,
# the same initial weights and therefore the same loss curve.
SEED = 0
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and targets. Since PyTorch 0.4 plain
# tensors take part in autograd directly, so the old `Variable` wrapper is not
# needed. Neither x nor y requires gradients; only the model parameters do.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
# Model: Linear(D_in -> H) -> ReLU -> Linear(H -> D_out), i.e. one hidden layer.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# Sum (not mean) of squared errors over all elements. `reduction='sum'` is the
# supported replacement for the deprecated `size_average=False`.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which parameters (tensors) it should update.
learning_rate = 1e-4
num_steps = 500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(num_steps):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss. `.item()` extracts the Python float from the
    # 0-dim loss tensor; the old `loss.data[0]` raises IndexError on
    # PyTorch >= 0.4.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, zero the gradients of every parameter the
    # optimizer manages. Gradients accumulate across .backward() calls otherwise.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to the model
    # parameters.
    loss.backward()

    # Calling step() on the Optimizer applies one Adam update to its parameters.
    optimizer.step()

0 676.1935424804688
1 658.651611328125
2 641.6445922851562
3 625.1925048828125
4 609.2385864257812
5 593.6788330078125
6 578.6488037109375
7 563.9906616210938
8 549.7250366210938
9 535.8424682617188
10 522.3507080078125
11 509.150390625
12 496.35101318359375
13 484.02099609375
14 472.0038757324219
15 460.3139343261719
16 448.982666015625
17 437.9811096191406
18 427.26116943359375
19 416.8269958496094
20 406.67706298828125
21 396.81927490234375
22 387.2613220214844
23 377.9117126464844
24 368.77056884765625
25 359.8409729003906
26 351.1045837402344
27 342.5767517089844
28 334.227294921875
29 326.0708312988281
30 318.1455078125
31 310.4358215332031
32 302.903076171875
33 295.5453796386719
34 288.3844909667969
35 281.38848876953125
36 274.5393371582031
37 267.8331604003906
38 261.26776123046875
39 254.83103942871094
40 248.51101684570312
41 242.3188018798828
42 236.28076171875
43 230.39768981933594
44 224.65179443359375
45 219.0424346923828
46 213.5535125732422
47 208.1712646484375
48 202

368 0.00011218230793019757
369 0.00010631121404003352
370 0.00010074776218971238
371 9.545985813019797e-05
372 9.04514963622205e-05
373 8.569495548726991e-05
374 8.11876670923084e-05
375 7.690890197409317e-05
376 7.285003084689379e-05
377 6.900169682921842e-05
378 6.535491411341354e-05
379 6.189181294757873e-05
380 5.861499084858224e-05
381 5.55014303245116e-05
382 5.255380165181123e-05
383 4.9759353714762256e-05
384 4.71074235974811e-05
385 4.459477713680826e-05
386 4.221396375214681e-05
387 3.995646693510935e-05
388 3.7816880649188533e-05
389 3.5791446862276644e-05
390 3.386912067071535e-05
391 3.204796666977927e-05
392 3.0323895771289244e-05
393 2.8691190891549923e-05
394 2.7141279133502394e-05
395 2.5676712539279833e-05
396 2.4287450287374668e-05
397 2.2973157683736645e-05
398 2.1727286366513e-05
399 2.054661490547005e-05
400 1.9431034161243588e-05
401 1.8374150386080146e-05
402 1.737139609758742e-05
403 1.6423611668869853e-05
404 1.5525809430982918e-05
405 1.4678960724268109e-05
4