In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 638.73681640625
1 590.0200805664062
2 548.521240234375
3 512.2881469726562
4 480.1826477050781
5 451.4676513671875
6 425.4517517089844
7 401.51812744140625
8 379.4933776855469
9 359.1138610839844
10 340.2139587402344
11 322.6264343261719
12 306.140625
13 290.6034240722656
14 275.8786315917969
15 261.8199768066406
16 248.44471740722656
17 235.66529846191406
18 223.45960998535156
19 211.853759765625
20 200.77011108398438
21 190.22412109375
22 180.16567993164062
23 170.5894012451172
24 161.46551513671875
25 152.8158721923828
26 144.58770751953125
27 136.75904846191406
28 129.28982543945312
29 122.15337371826172
30 115.36068725585938
31 108.91992950439453
32 102.8168716430664
33 97.03398132324219
34 91.546142578125
35 86.36146545410156
36 81.45854187011719
37 76.82139587402344
38 72.44647979736328
39 68.30876159667969
40 64.40553283691406
41 60.7252197265625
42 57.249961853027344
43 53.971351623535156
44 50.883056640625
45 47.970947265625
46 45.22929000854492
47 42.65199279785156
48 40.2

451 0.00010717331315390766
452 0.00010461594501975924
453 0.00010211743938270956
454 9.968122321879491e-05
455 9.730372403282672e-05
456 9.498513099970296e-05
457 9.272102761315182e-05
458 9.051265806192532e-05
459 8.835826884023845e-05
460 8.625619375379756e-05
461 8.420453377766535e-05
462 8.22001020424068e-05
463 8.024642738746479e-05
464 7.83402647357434e-05
465 7.647828897461295e-05
466 7.466244278475642e-05
467 7.288872438948601e-05
468 7.115786138456315e-05
469 6.947155634406954e-05
470 6.782282434869558e-05
471 6.621682405238971e-05
472 6.464844045694917e-05
473 6.311660399660468e-05
474 6.162118370411918e-05
475 6.0162736190250143e-05
476 5.8738183724926785e-05
477 5.7348457630723715e-05
478 5.599293581326492e-05
479 5.466971197165549e-05
480 5.337887705536559e-05
481 5.2117015002295375e-05
482 5.0886330427601933e-05
483 4.968500434188172e-05
484 4.85144009871874e-05
485 4.7368132072733715e-05
486 4.625079600373283e-05
487 4.516094486461952e-05
488 4.409491521073505e-05
489 4.