In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [2]:
import torch

# Problem sizes: N samples per batch, D_in input features,
# H hidden units, D_out output features.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random Tensors standing in for the inputs and the targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Network built from the nn package: Linear -> ReLU -> Linear.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H), torch.nn.ReLU(), torch.nn.Linear(H, D_out)
)

# Sum-reduced mean-squared-error loss, also taken from the nn package.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Let the optim package perform the weight updates. Adam is used here; optim
# ships many other optimization algorithms. The optimizer is handed the model
# parameters, i.e. the Tensors it is responsible for updating.
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

for t in range(500):
    # Clear the gradients of every Tensor the optimizer manages. Gradients
    # are accumulated (not overwritten) on each .backward() call, so they
    # have to be reset once per iteration; see torch.autograd.backward.
    # Doing it ahead of the forward pass is equivalent to doing it right
    # before backward(), since the forward pass never touches .grad.
    optimizer.zero_grad()

    # Forward pass: run x through the model to get the prediction.
    y_pred = model(x)

    # Measure the loss against the targets and report it.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Backward pass: gradient of the loss w.r.t. the model parameters.
    loss.backward()

    # Ask the optimizer to apply one update to its parameters.
    optimizer.step()

0 665.5523071289062
1 648.934814453125
2 632.8447875976562
3 617.218994140625
4 602.0975341796875
5 587.3609008789062
6 573.0096435546875
7 559.0433349609375
8 545.4344482421875
9 532.2589111328125
10 519.4659423828125
11 507.08160400390625
12 495.0038146972656
13 483.202392578125
14 471.71502685546875
15 460.5299072265625
16 449.61199951171875
17 438.98956298828125
18 428.6252136230469
19 418.607666015625
20 408.89739990234375
21 399.4598693847656
22 390.2684631347656
23 381.2797546386719
24 372.5262451171875
25 363.9903869628906
26 355.6448974609375
27 347.5204772949219
28 339.6602783203125
29 331.96484375
30 324.4687805175781
31 317.1211242675781
32 309.96820068359375
33 303.02349853515625
34 296.23748779296875
35 289.58514404296875
36 283.06231689453125
37 276.6483459472656
38 270.37615966796875
39 264.2242126464844
40 258.22613525390625
41 252.37753295898438
42 246.6571044921875
43 241.04368591308594
44 235.5445556640625
45 230.18002319335938
46 224.95123291015625
47 219.825805664

462 2.5675097603539143e-08
463 2.369850271577434e-08
464 2.192421710844883e-08
465 2.0251741617016705e-08
466 1.8728147921365235e-08
467 1.731308785224428e-08
468 1.596675325288288e-08
469 1.4777982393354705e-08
470 1.3632377005023955e-08
471 1.2622883183155409e-08
472 1.1660506338273535e-08
473 1.0789769078201061e-08
474 9.975986259291858e-09
475 9.225444408400563e-09
476 8.546588325941684e-09
477 7.89668774814345e-09
478 7.301280025018286e-09
479 6.754469872305435e-09
480 6.239982752731521e-09
481 5.783970191686194e-09
482 5.3477324790662806e-09
483 4.9582666861169855e-09
484 4.604148617914916e-09
485 4.255103824846174e-09
486 3.960134442593244e-09
487 3.6559206773034703e-09
488 3.404293513753487e-09
489 3.1548392787073e-09
490 2.9164211046150967e-09
491 2.709515944943064e-09
492 2.5118473967467025e-09
493 2.3377419999803806e-09
494 2.173790925041885e-09
495 2.022987333205606e-09
496 1.8836645576669753e-09
497 1.7486689873891237e-09
498 1.6257779567041553e-09
499 1.5165992905963321e-