In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [1]:
import torch

# Train a one-hidden-layer fully-connected ReLU network on random data, using
# torch.optim.Adam to perform the weight updates.

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction='sum' makes the loss the summed squared error over all elements.
# It replaces the deprecated size_average=False argument, which selects the
# same reduction but emits a deprecation warning.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable
    # weights of the model). This is because, by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to the
    # model parameters.
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters.
    optimizer.step()

0 659.9446411132812
1 643.079833984375
2 626.7293701171875
3 610.8629150390625
4 595.486328125
5 580.5899047851562
6 566.1458740234375
7 552.1389770507812
8 538.5479736328125
9 525.37890625
10 512.5763549804688
11 500.09844970703125
12 487.9282531738281
13 476.062744140625
14 464.61932373046875
15 453.5344543457031
16 442.75335693359375
17 432.2699279785156
18 422.1927795410156
19 412.4147033691406
20 402.85882568359375
21 393.56439208984375
22 384.5235900878906
23 375.75927734375
24 367.2861328125
25 359.0279846191406
26 350.98785400390625
27 343.12152099609375
28 335.4375915527344
29 327.973876953125
30 320.7315979003906
31 313.6475830078125
32 306.7467346191406
33 299.9906921386719
34 293.4124755859375
35 286.9823303222656
36 280.6702880859375
37 274.45574951171875
38 268.3796691894531
39 262.4101257324219
40 256.5552062988281
41 250.82835388183594
42 245.2093963623047
43 239.68589782714844
44 234.26748657226562
45 228.9537353515625
46 223.7402801513672
47 218.61764526367188
48 213.

386 1.0221867796644801e-06
387 9.410269967702334e-07
388 8.668312716508808e-07
389 7.978343319337e-07
390 7.339559147112595e-07
391 6.753392085556698e-07
392 6.21314313775656e-07
393 5.712464599127998e-07
394 5.25133145856671e-07
395 4.826004555980035e-07
396 4.4358378659126174e-07
397 4.0749384311311587e-07
398 3.741901650755608e-07
399 3.4358612310825265e-07
400 3.15575107379118e-07
401 2.896519788464502e-07
402 2.657906748027017e-07
403 2.4377138174713764e-07
404 2.2360958951139764e-07
405 2.049890497346496e-07
406 1.8782942845518846e-07
407 1.720693774132087e-07
408 1.5758571692003898e-07
409 1.4427386929583008e-07
410 1.3220029870808503e-07
411 1.2105951441299112e-07
412 1.1084112827575154e-07
413 1.0134195349564834e-07
414 9.263317224394996e-08
415 8.470993151377115e-08
416 7.739018315078283e-08
417 7.07800751342802e-08
418 6.461105783728271e-08
419 5.894050403298934e-08
420 5.3990731885278365e-08
421 4.9150600034408853e-08
422 4.495553795891283e-08
423 4.096182593116282e-08
424 