In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.  
这个实现将模型定义为一个自定义的 Module 子类。当你需要的模型比“若干现有 Module 的简单顺序组合”更复杂时，就需要用这种方式来定义自己的模型。



In [1]:
import torch
from torch.autograd import Variable


class TwoLayerNet(torch.nn.Module):  # inherits from torch.nn.Module
    """
    Two stacked nn.Linear layers with a ReLU-style non-linearity in between
    (implemented by clamping the hidden activations at zero from below).
    """

    def __init__(self, D_in, H, D_out):
        """
        Create the two nn.Linear sub-modules and store them as attributes.

        D_in is the input feature size of the first layer, H is the size of
        the hidden representation shared by both layers, and D_out is the
        output feature size of the second layer.
        """
        super().__init__()
        # Construction order (linear1, then linear2) is kept so that parameter
        # initialisation consumes the random number generator in the same order.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map the input x to a prediction: linear1 -> clamp(min=0) -> linear2.

        Only the forward computation has to be written by hand; the backward
        pass is derived automatically by autograd.
        """
        hidden = self.linear1(x)
        # clamp(min=0) zeroes out negative activations, i.e. acts as a ReLU.
        activated = hidden.clamp(min=0)
        return self.linear2(activated)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and targets. Since PyTorch 0.4 plain
# Tensors carry autograd information themselves, so the deprecated Variable
# wrapper is not needed; neither tensor requires gradients here.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
# reduction='sum' is the supported spelling of the deprecated
# size_average=False: the squared errors are summed, not averaged.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim tensor, so indexing it with
    # [0] raises on current PyTorch; .item() extracts the Python float.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared
    # before each new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 639.1766357421875
1 593.60986328125
2 554.0064086914062
3 518.9912719726562
4 487.8104553222656
5 459.76519775390625
6 434.1590881347656
7 410.4022216796875
8 388.2483825683594
9 367.6387634277344
10 348.5289001464844
11 330.6759338378906
12 313.9512634277344
13 298.21136474609375
14 283.1641540527344
15 268.90869140625
16 255.342529296875
17 242.35098266601562
18 229.9123992919922
19 218.02145385742188
20 206.65176391601562
21 195.7445831298828
22 185.31903076171875
23 175.36962890625
24 165.85816955566406
25 156.78968811035156
26 148.15066528320312
27 139.87130737304688
28 131.9964599609375
29 124.5095443725586
30 117.38334655761719
31 110.61993408203125
32 104.2024917602539
33 98.10753631591797
34 92.34053802490234
35 86.8803939819336
36 81.71894073486328
37 76.84841918945312
38 72.25601196289062
39 67.91532135009766
40 63.815155029296875
41 59.9547004699707
42 56.31075668334961
43 52.87302017211914
44 49.64067459106445
45 46.60267639160156
46 43.74545669555664
47 41.0632171630859

396 4.9204918468603864e-05
397 4.774182889377698e-05
398 4.632453419617377e-05
399 4.4951179006602615e-05
400 4.361612809589133e-05
401 4.2325576941948384e-05
402 4.106961205252446e-05
403 3.9852042391430587e-05
404 3.8672118535032496e-05
405 3.7525820516748354e-05
406 3.641564762801863e-05
407 3.5339569876668975e-05
408 3.429291245993227e-05
409 3.3278523915214464e-05
410 3.2295491109834984e-05
411 3.1343643058789894e-05
412 3.0418997994274832e-05
413 2.9517213988583535e-05
414 2.8646978535107337e-05
415 2.78025272564264e-05
416 2.6982099370798096e-05
417 2.6187206458416767e-05
418 2.5414108677068725e-05
419 2.466560727043543e-05
420 2.3939246602822095e-05
421 2.3235588741954416e-05
422 2.2553163944394328e-05
423 2.1890507923671976e-05
424 2.124492675648071e-05
425 2.0621799194486812e-05
426 2.0013447283417918e-05
427 1.9426688595558517e-05
428 1.885684105218388e-05
429 1.8301941963727586e-05
430 1.776375029294286e-05
431 1.7244499758817255e-05
432 1.6738869817345403e-05
433 1.6247809