In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.
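The formulas below write out the model and objective. They are a sketch that follows the conventions torch.nn.Linear uses: rows of $X$ are samples, and each layer computes $XW^\top + b$ with the bias added to every row. $N$, $D_{\text{in}}$, $H$, $D_{\text{out}}$ and the learning rate $\eta$ are the values set in the code cell.

```latex
\hat{Y} = \max\!\left(0,\; X W_1^\top + b_1\right) W_2^\top + b_2,
\qquad X \in \mathbb{R}^{N \times D_{\text{in}}},\;
W_1 \in \mathbb{R}^{H \times D_{\text{in}}},\; b_1 \in \mathbb{R}^{H},\;
W_2 \in \mathbb{R}^{D_{\text{out}} \times H},\; b_2 \in \mathbb{R}^{D_{\text{out}}}

L = \lVert \hat{Y} - Y \rVert_F^2
  = \sum_{n=1}^{N} \sum_{k=1}^{D_{\text{out}}} \left(\hat{Y}_{nk} - Y_{nk}\right)^2

\theta \leftarrow \theta - \eta \, \nabla_\theta L,
\qquad \theta \in \{W_1, b_1, W_2, b_2\},\; \eta = 10^{-4}
```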

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
each of which you can think of as a neural network layer that produces output from
input and may have some trainable weights.
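The short sketch below makes the Module idea concrete by inspecting a single torch.nn.Linear layer. It assumes PyTorch 1.x or later. The layer is callable like a function, owns trainable weight and bias Parameters, and computes x @ weight.t() + bias. The sizes only mirror the tutorial's D_in, H and N, and the variable names are illustrative.

```python
import torch

# A Linear Module mapping 1000 inputs to 100 outputs (mirrors D_in -> H).
layer = torch.nn.Linear(1000, 100)

# Trainable weights are exposed as Parameters (Tensors with requires_grad=True).
for name, p in layer.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# weight (100, 1000) True
# bias (100,) True

# Modules are callable: a batch of 64 inputs produces a batch of 64 outputs.
x = torch.randn(64, 1000)
out = layer(x)
print(tuple(out.shape))  # (64, 100)

# Linear applies an affine map: x @ weight^T + bias (bias added to every row).
manual = x @ layer.weight.t() + layer.bias
print(torch.allclose(out, manual, atol=1e-5))  # True

# Not every Module has weights: ReLU is parameter-free.
print(len(list(torch.nn.ReLU().parameters())))  # 0
```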



In [3]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs. Since PyTorch 0.4, plain
# Tensors support autograd directly, so wrapping them in Variables is unnecessary.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Parameters (Tensors with requires_grad=True)
# for its weight and bias.

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions. Here we
# use MSELoss with reduction='sum', which returns the sum (not the mean) of the
# squared errors, i.e. the squared Euclidean distance between y_pred and y.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a 0-dim Tensor containing the
    # loss. .item() extracts it as a Python float (loss.data[0] raises an
    # IndexError on 0-dim Tensors in current PyTorch).
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor whose
    # gradient is stored in param.grad. The update runs under torch.no_grad() so
    # that autograd does not record it as part of the graph.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 654.0662231445312
1 604.6380004882812
2 562.0304565429688
3 524.3499145507812
4 490.8837890625
5 460.7255554199219
6 433.1714172363281
7 408.03472900390625
8 384.84149169921875
9 363.64959716796875
10 344.0705871582031
11 325.87738037109375
12 308.7237854003906
13 292.5126037597656
14 277.1753845214844
15 262.6740417480469
16 248.90330505371094
17 235.79222106933594
18 223.3400421142578
19 211.54122924804688
20 200.3196563720703
21 189.63558959960938
22 179.51756286621094
23 169.89222717285156
24 160.76419067382812
25 152.05630493164062
26 143.7950897216797
27 135.95086669921875
28 128.51376342773438
29 121.44344329833984
30 114.74948120117188
31 108.41526794433594
32 102.3966293334961
33 96.66917419433594
34 91.25836944580078
35 86.14564514160156
36 81.30431365966797
37 76.72748565673828
38 72.41210174560547
39 68.33642578125
40 64.48467254638672
41 60.86014175415039
42 57.44373321533203
43 54.221351623535156
44 51.18336868286133
45 48.3228759765625
46 45.62440872192383
47 43.078601

460 3.398528497200459e-05
461 3.313972047180869e-05
462 3.231535811210051e-05
463 3.1509149266639724e-05
464 3.072643085033633e-05
465 2.9964314308017492e-05
466 2.921777922892943e-05
467 2.849215889000334e-05
468 2.7784517442341894e-05
469 2.7095522455056198e-05
470 2.6422523660585284e-05
471 2.5767487386474386e-05
472 2.5130144422291778e-05
473 2.4506012778147124e-05
474 2.3898741346783936e-05
475 2.33063219639007e-05
476 2.272993151564151e-05
477 2.2168249415699393e-05
478 2.1620377083308995e-05
479 2.108456646965351e-05
480 2.0563918951665983e-05
481 2.005406167882029e-05
482 1.9559136489988305e-05
483 1.9075978343607858e-05
484 1.8603750504553318e-05
485 1.8145043213735335e-05
486 1.7695590941002592e-05
487 1.7260039385291748e-05
488 1.6832891560625285e-05
489 1.6419076928286813e-05
490 1.601271469553467e-05
491 1.5619039913872257e-05
492 1.5234511920425575e-05
493 1.4859773727948777e-05
494 1.4492913578578737e-05
495 1.4135296623862814e-05
496 1.3787103853246663e-05
497 1.3448096
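The comments in the training cell make three claims that are easy to check in isolation. First, nn.Sequential feeds each Module's output into the next. Second, MSELoss with reduction='sum' (the modern replacement for the size_average=False argument) is a plain sum of squared errors. Third, gradients accumulate unless they are zeroed. This is a minimal sketch assuming PyTorch 1.x or later. The tiny sizes and the seed are arbitrary illustrative choices, not part of the original tutorial.

```python
import torch

torch.manual_seed(0)  # arbitrary seed so the run is repeatable
N, D_in, H, D_out = 4, 3, 5, 2  # small illustrative sizes

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# 1. Sequential applies its Modules in order, each consuming the previous output.
y_pred = model(x)
by_hand = model[2](model[1](model[0](x)))
print(torch.allclose(y_pred, by_hand))  # True

# 2. reduction='sum' gives the summed squared error, not the mean.
loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(y_pred, y)
print(torch.allclose(loss, ((y_pred - y) ** 2).sum()))  # True

# 3. backward() adds into .grad, so a second call without zeroing doubles it.
model.zero_grad()
loss_fn(model(x), y).backward()
first = model[0].weight.grad.clone()
loss_fn(model(x), y).backward()
print(torch.allclose(model[0].weight.grad, 2 * first))  # True
```

The third check is why model.zero_grad() runs before every loss.backward() in the training loop. Without it, each update would use the sum of all gradients computed so far rather than the gradient of the current loss.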