In [None]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as neural network layers that produce output from
input and may have some trainable weights.



In [1]:
import torch

# Problem dimensions: N is the batch size, D_in the input dimension,
# H the hidden dimension and D_out the output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; here we
# use the squared-error loss. reduction='sum' adds up the squared differences
# over every element instead of averaging them, i.e. the loss is the squared
# Euclidean distance between prediction and target. It is the supported
# replacement for the deprecated size_average=False argument.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Step size used by the manual gradient-descent update below.
learning_rate = 1e-4
for t in range(500):
    # Forward pass. Modules are callable: handing the input Tensor to the
    # model runs every layer in order and returns the output Tensor.
    y_pred = model(x)

    # Evaluate the loss between prediction and target and report it. The loss
    # function returns a Tensor; .item() extracts its Python number.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Clear the gradients left over from the previous iteration before
    # running the backward pass.
    model.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to every
    # learnable parameter of the model. Each Module keeps its parameters in
    # Tensors with requires_grad=True, so this single call fills in .grad
    # for all of them.
    loss.backward()

    # Gradient-descent step. Every parameter is a Tensor, so its gradient is
    # available as param.grad. The update runs under no_grad so that autograd
    # does not track it.
    with torch.no_grad():
        for param in model.parameters():
            param.sub_(learning_rate * param.grad)

0 607.8184204101562
1 562.2053833007812
2 522.6771850585938
3 488.0782775878906
4 457.2925109863281
5 429.5785827636719
6 404.5616149902344
7 381.7408447265625
8 360.7369689941406
9 341.3135681152344
10 323.2158508300781
11 306.1913757324219
12 290.18914794921875
13 275.162109375
14 260.8387451171875
15 247.28082275390625
16 234.37844848632812
17 222.14157104492188
18 210.466796875
19 199.36422729492188
20 188.74285888671875
21 178.58004760742188
22 168.86013793945312
23 159.60501098632812
24 150.78314208984375
25 142.3392333984375
26 134.2963104248047
27 126.6358871459961
28 119.34710693359375
29 112.40806579589844
30 105.82416534423828
31 99.57060241699219
32 93.6412353515625
33 88.03064727783203
34 82.71832275390625
35 77.7052001953125
36 72.97801971435547
37 68.5078353881836
38 64.28417205810547
39 60.30238723754883
40 56.531944274902344
41 52.98442459106445
42 49.64726638793945
43 46.511165618896484
44 43.55995559692383
45 40.78837203979492
46 38.181522369384766
47 35.739044189453

351 5.856501593370922e-05
352 5.65330519748386e-05
353 5.457560109789483e-05
354 5.2682833484141156e-05
355 5.086211604066193e-05
356 4.910432471660897e-05
357 4.7404257202288136e-05
358 4.576425635605119e-05
359 4.418281969265081e-05
360 4.265847019269131e-05
361 4.118378274142742e-05
362 3.976379593950696e-05
363 3.839366036118008e-05
364 3.707038558786735e-05
365 3.5792159906122833e-05
366 3.455995101830922e-05
367 3.337138696224429e-05
368 3.2222313166130334e-05
369 3.11131079797633e-05
370 3.0045219318708405e-05
371 2.901433254010044e-05
372 2.8017939257551916e-05
373 2.7056399630964734e-05
374 2.6128154786420055e-05
375 2.5234254280803725e-05
376 2.4369615857722238e-05
377 2.3531099941465072e-05
378 2.272600613650866e-05
379 2.1947889763396233e-05
380 2.1198355170781724e-05
381 2.0473393306019716e-05
382 1.977404099307023e-05
383 1.9096250980510376e-05
384 1.8445531168254092e-05
385 1.7815607861848548e-05
386 1.7208292774739675e-05
387 1.662128488533199e-05
388 1.6053449144237675