In [None]:
%matplotlib inline


PyTorch: optim
--------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.



In [1]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. 
# nn.Sequential is a Module which contains other Modules, and applies them in sequence to produce its output.
# Each Linear Module computes output from input using a linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this case we will use MSE as loss function
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of the model for us. Here we will use Adam;
# the optim package contains many other optimization algorithms. The first argument to the Adam constructor tells the 
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    # Forward pass
    y_pred = model(x)
    
    # Compute and print loss
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    # Zero the gradients before running the backward pass
    model.zero_grad()
    
    # Backward pass
    loss.backward()
    
    # Calling the step function on an Optimizer makes and update to its parameters
    optimizer.step()

0 668.2874145507812
1 651.3516235351562
2 634.93603515625
3 618.9886474609375
4 603.5103149414062
5 588.5125122070312
6 574.0078735351562
7 560.0535278320312
8 546.5635375976562
9 533.6114501953125
10 521.0109252929688
11 508.7533874511719
12 496.869140625
13 485.33660888671875
14 474.130615234375
15 463.19610595703125
16 452.5824279785156
17 442.29461669921875
18 432.24371337890625
19 422.4303283691406
20 412.8651428222656
21 403.57415771484375
22 394.53753662109375
23 385.7462158203125
24 377.168212890625
25 368.8588562011719
26 360.7206726074219
27 352.75946044921875
28 344.9613037109375
29 337.3985290527344
30 330.0144958496094
31 322.7799987792969
32 315.6879577636719
33 308.74859619140625
34 301.9581298828125
35 295.3166198730469
36 288.8456115722656
37 282.53167724609375
38 276.32000732421875
39 270.2302551269531
40 264.273681640625
41 258.4519348144531
42 252.7480926513672
43 247.15467834472656
44 241.67562866210938
45 236.3123016357422
46 231.05999755859375
47 225.914321899414

412 0.00022287586762104183
413 0.00021492222731467336
414 0.00020724249770864844
415 0.0001998257648665458
416 0.00019266318122390658
417 0.00018574799469206482
418 0.0001790695678209886
419 0.00017262226901948452
420 0.0001663975272094831
421 0.00016038618923630565
422 0.00015458637790288776
423 0.00014898668450769037
424 0.00014357944019138813
425 0.00013835993013344705
426 0.00013332499656826258
427 0.00012846261961385608
428 0.00012377265375107527
429 0.00011924491263926029
430 0.00011487580923130736
431 0.00011066246952395886
432 0.00010659592953743413
433 0.00010267324978485703
434 9.889101784210652e-05
435 9.523727203486487e-05
436 9.171483543468639e-05
437 8.831632294459268e-05
438 8.503918070346117e-05
439 8.188186620827764e-05
440 7.883077341830358e-05
441 7.589475717395544e-05
442 7.305986218852922e-05
443 7.032795838313177e-05
444 6.769032188458368e-05
445 6.515228596981615e-05
446 6.270472658798099e-05
447 6.034463149262592e-05
448 5.8069901569979265e-05
449 5.587740815826