In [1]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [3]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
#     torch.nn.ReLU(),
#     torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss()

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 1.0137114524841309
1 1.0135858058929443
2 1.013458251953125
3 1.0133311748504639
4 1.0132057666778564
5 1.0130784511566162
6 1.0129516124725342
7 1.0128253698349
8 1.0126993656158447
9 1.012573003768921
10 1.0124465227127075
11 1.0123196840286255
12 1.0121941566467285
13 1.0120679140090942
14 1.0119407176971436
15 1.0118153095245361
16 1.011688470840454
17 1.0115633010864258
18 1.0114367008209229
19 1.011310338973999
20 1.0111844539642334
21 1.0110580921173096
22 1.0109338760375977
23 1.0108063220977783
24 1.010681390762329
25 1.0105546712875366
26 1.0104286670684814
27 1.0103027820587158
28 1.010176658630371
29 1.0100510120391846
30 1.0099258422851562
31 1.0098000764846802
32 1.009673833847046
33 1.009547472000122
34 1.009421706199646
35 1.009296178817749
36 1.0091707706451416
37 1.0090446472167969
38 1.0089197158813477
39 1.0087929964065552
40 1.0086677074432373
41 1.0085426568984985
42 1.0084171295166016
43 1.0082911252975464
44 1.0081666707992554
45 1.0080410242080688
46 1.007915

456 0.9588869214057922
457 0.9587734937667847
458 0.9586593508720398
459 0.9585461616516113
460 0.9584329724311829
461 0.9583184123039246
462 0.9582053422927856
463 0.9580910801887512
464 0.9579782485961914
465 0.9578644633293152
466 0.9577502012252808
467 0.9576373100280762
468 0.9575239419937134
469 0.957409679889679
470 0.957297146320343
471 0.957183837890625
472 0.9570702314376831
473 0.9569562673568726
474 0.9568436741828918
475 0.9567298889160156
476 0.956616222858429
477 0.9565026164054871
478 0.9563896059989929
479 0.9562757611274719
480 0.9561635255813599
481 0.9560497999191284
482 0.9559370279312134
483 0.9558240175247192
484 0.9557105302810669
485 0.9555975794792175
486 0.9554837346076965
487 0.9553712606430054
488 0.9552586674690247
489 0.955145537853241
490 0.955032229423523
491 0.954919159412384
492 0.9548062086105347
493 0.9546929597854614
494 0.9545804858207703
495 0.9544674754142761
496 0.9543545842170715
497 0.9542415738105774
498 0.9541285634040833
499 0.954016327857