In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(size_average=False)

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access and gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 633.8336791992188
1 589.7041625976562
2 551.2622680664062
3 517.116943359375
4 486.2258605957031
5 458.6175231933594
6 433.94647216796875
7 411.2509460449219
8 390.02301025390625
9 370.27117919921875
10 351.6968688964844
11 334.1072998046875
12 317.4642639160156
13 301.63592529296875
14 286.56903076171875
15 272.20367431640625
16 258.4706115722656
17 245.37796020507812
18 232.88429260253906
19 220.93081665039062
20 209.51284790039062
21 198.60513305664062
22 188.1500244140625
23 178.2006378173828
24 168.7314910888672
25 159.699951171875
26 151.06309509277344
27 142.8150177001953
28 134.97702026367188
29 127.5297622680664
30 120.45606231689453
31 113.74514770507812
32 107.3612060546875
33 101.31586456298828
34 95.5934829711914
35 90.17866516113281
36 85.06742858886719
37 80.2416000366211
38 75.68746948242188
39 71.37828063964844
40 67.32025146484375
41 63.50246047973633
42 59.90629196166992
43 56.51531982421875
44 53.322628021240234
45 50.31420135498047
46 47.476539611816406
47 44.804

442 1.8288179489900358e-05
443 1.77577185240807e-05
444 1.724223511700984e-05
445 1.6740661521907896e-05
446 1.6255049558822066e-05
447 1.5783789422130212e-05
448 1.5326373613788746e-05
449 1.4881195056659635e-05
450 1.4450952221523039e-05
451 1.4031737919140141e-05
452 1.3625885912915692e-05
453 1.3231367120170034e-05
454 1.2848758160544094e-05
455 1.2477651580411475e-05
456 1.211639118992025e-05
457 1.176718797069043e-05
458 1.1426953278714791e-05
459 1.1097610695287585e-05
460 1.077760498446878e-05
461 1.046656325343065e-05
462 1.0164394552703016e-05
463 9.871752808976453e-06
464 9.587321983417496e-06
465 9.312365364166908e-06
466 9.04391890799161e-06
467 8.783546036283951e-06
468 8.530721970601007e-06
469 8.285817784781102e-06
470 8.04735282144975e-06
471 7.816941433702596e-06
472 7.592598649353022e-06
473 7.374027063633548e-06
474 7.161572739278199e-06
475 6.956786819500849e-06
476 6.757627943443367e-06
477 6.5639255808491725e-06
478 6.375421889970312e-06
479 6.193367426021723e-06