In [1]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [2]:
import torch

# Train a two-layer fully-connected ReLU network on random data, letting a
# torch.optim optimizer perform the weight updates.

# N is the batch size, D_in the input dimension,
# H the hidden dimension, and D_out the output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' adds up the squared errors over every element instead of
# averaging them. It is the supported replacement for the deprecated
# size_average=False argument and yields the same summed loss.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable
    # weights of the model). This is necessary because, by default, gradients
    # are accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to the
    # model parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters.
    optimizer.step()

0 629.5075073242188
1 612.67822265625
2 596.3663940429688
3 580.493896484375
4 565.1577758789062
5 550.248046875
6 535.7717895507812
7 521.7114868164062
8 508.1180419921875
9 494.904296875
10 482.14398193359375
11 469.78814697265625
12 457.8339538574219
13 446.2250061035156
14 434.942626953125
15 424.028076171875
16 413.4199523925781
17 403.1151123046875
18 393.1701965332031
19 383.5664978027344
20 374.2372741699219
21 365.19378662109375
22 356.3787841796875
23 347.7757263183594
24 339.38909912109375
25 331.21295166015625
26 323.2369384765625
27 315.4633483886719
28 307.9411315917969
29 300.5807800292969
30 293.3706970214844
31 286.36956787109375
32 279.6441955566406
33 273.0813293457031
34 266.6577453613281
35 260.3767395019531
36 254.22909545898438
37 248.2185821533203
38 242.34034729003906
39 236.58888244628906
40 230.96273803710938
41 225.45364379882812
42 220.0692901611328
43 214.80337524414062
44 209.6564483642578
45 204.5985107421875
46 199.6671905517578
47 194.8323974609375
48 

423 6.894770763210545e-07
424 6.403686256817309e-07
425 5.945742600488302e-07
426 5.523767754311848e-07
427 5.129215878696414e-07
428 4.7628316224290757e-07
429 4.4213567207407323e-07
430 4.103050059711677e-07
431 3.8099548760328616e-07
432 3.533941139721719e-07
433 3.2794127946544904e-07
434 3.043293759219523e-07
435 2.822219471454446e-07
436 2.6180799750363803e-07
437 2.426910157282691e-07
438 2.2500432805827586e-07
439 2.086441241999637e-07
440 1.9333748468852718e-07
441 1.7915203898155596e-07
442 1.6602295715983928e-07
443 1.5382273943487235e-07
444 1.4238281664802344e-07
445 1.3186974001655472e-07
446 1.2211593514166452e-07
447 1.1307682967753863e-07
448 1.0466260391694959e-07
449 9.688307045507827e-08
450 8.96468037581144e-08
451 8.300337839273197e-08
452 7.672222324117683e-08
453 7.10691239191874e-08
454 6.568348709379279e-08
455 6.077902980905492e-08
456 5.6140869730825216e-08
457 5.1913957577198744e-08
458 4.794602403990211e-08
459 4.435236178323976e-08
460 4.0956130931135704e