PyTorch: nn

A fully-connected ReLU network with one hidden layer, trained to predict y from x by minimizing squared Euclidean distance.
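
As a quick check of what "squared Euclidean distance" means here, the sketch below computes the sum of squared differences by hand and compares it with `torch.nn.MSELoss(reduction='sum')`, the setting used in the code cell further down. The tensor sizes, the seed, and the variable names `manual_loss`, `builtin_loss`, and `mean_loss` are made up for illustration and are not part of the original example.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the sketch repeatable

# Small made-up tensors standing in for predictions and targets.
y_pred = torch.randn(4, 3)
y = torch.randn(4, 3)

# Squared Euclidean distance written out with plain tensor operations.
manual_loss = (y_pred - y).pow(2).sum()

# The same quantity from the built-in loss with reduction='sum'.
loss_fn = torch.nn.MSELoss(reduction='sum')
builtin_loss = loss_fn(y_pred, y)

# With reduction='mean' the sum is divided by the number of elements (4 * 3).
mean_loss = torch.nn.MSELoss(reduction='mean')(y_pred, y)

print(manual_loss.item(), builtin_loss.item(), mean_loss.item() * y.numel())
```

The three printed values should agree up to floating-point rounding, which shows that `reduction='sum'` yields the summed squared error rather than its mean.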

This implementation uses the nn package from PyTorch to build the network. PyTorch autograd makes it easy to define computational graphs and take gradients, but raw autograd can be a bit too low-level for defining complex neural networks; this is where the nn package can help. The nn package defines a set of Modules, each of which you can think of as a neural network layer that produces output from input and may have some trainable weights.
Source: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
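
To make the idea of a Module concrete, here is a minimal sketch that builds a single `torch.nn.Linear` layer, calls it on an input, and lists its trainable parameters. The sizes (3 inputs, 2 outputs, batch of 5) and the names `layer`, `param_shapes`, and `relu_param_count` are assumptions chosen only for illustration.

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
layer = torch.nn.Linear(in_features=3, out_features=2)

# A Module is callable: it maps an input Tensor to an output Tensor.
x = torch.randn(5, 3)  # batch of 5 inputs with 3 features each
out = layer(x)
print(out.shape)  # torch.Size([5, 2])

# Its trainable state is exposed through named_parameters().
param_shapes = {name: tuple(p.shape) for name, p in layer.named_parameters()}
print(param_shapes)  # {'weight': (2, 3), 'bias': (2,)}

# Modules without weights, such as ReLU, have no parameters at all.
relu_param_count = len(list(torch.nn.ReLU().parameters()))
print(relu_param_count)  # 0
```

This covers the "may have some trainable weights" part of the description: `Linear` holds a weight and a bias, while `ReLU` holds nothing and only transforms its input.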

In [1]:
# -*- coding: utf-8 -*-
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we use MSELoss with reduction='sum', which sums the squared errors
# instead of averaging them.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # its gradient is available as param.grad. The update runs under
    # torch.no_grad() so that autograd does not track it.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 722.53564453125
1 664.0517578125
2 613.7988891601562
3 570.0084838867188
4 531.0029296875
5 495.97418212890625
6 464.13482666015625
7 435.1578063964844
8 408.4244689941406
9 383.6463317871094
10 360.5606384277344
11 339.0355224609375
12 318.8609924316406
13 299.897216796875
14 282.0491638183594
15 265.18658447265625
16 249.25808715820312
17 234.18453979492188
18 219.9244384765625
19 206.46539306640625
20 193.77316284179688
21 181.81448364257812
22 170.5233154296875
23 159.8327178955078
24 149.74295043945312
25 140.27542114257812
26 131.4088897705078
27 123.07568359375
28 115.2274169921875
29 107.86290740966797
30 100.96308898925781
31 94.50130462646484
32 88.45410919189453
33 82.80554962158203
34 77.52965545654297
35 72.60050964355469
36 67.99993133544922
37 63.693824768066406
38 59.654598236083984
39 55.889564514160156
40 52.37618637084961
41 49.09635543823242
42 46.031890869140625
43 43.16814422607422
44 40.493412017822266
45 37.99457931518555
46 35.64067459106445
47 33.44218444824

354 6.63053579046391e-05
355 6.407646287698299e-05
356 6.192243745317683e-05
357 5.984721428831108e-05
358 5.78443068661727e-05
359 5.590790169662796e-05
360 5.4038362577557564e-05
361 5.222820254857652e-05
362 5.0483951781643555e-05
363 4.8800913646118715e-05
364 4.7173780330922455e-05
365 4.560259549180046e-05
366 4.408174936543219e-05
367 4.2614483390934765e-05
368 4.120100129512139e-05
369 3.983226270065643e-05
370 3.8509740988956764e-05
371 3.723286499734968e-05
372 3.600196578190662e-05
373 3.4807660995284095e-05
374 3.365829979884438e-05
375 3.2544583518756554e-05
376 3.147041570628062e-05
377 3.0430812330450863e-05
378 2.9430048016365618e-05
379 2.845847120624967e-05
380 2.75207403319655e-05
381 2.6615291062626056e-05
382 2.5741806894075125e-05
383 2.489608777977992e-05
384 2.4081529772956856e-05
385 2.3289589080377482e-05
386 2.2528678528033197e-05
387 2.1792731786263175e-05
388 2.1078420104458928e-05
389 2.0392104488564655e-05
390 1.9724267986021005e-05
391 1.907992918859236e