In [1]:
import torch
from torch.autograd import Variable

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs. Plain Tensors carry
# autograd state themselves (Tensor and the deprecated `Variable` were merged
# in PyTorch 0.4), so no wrapping is needed. Neither x nor y needs gradients
# (requires_grad defaults to False); only the model parameters do.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
# reduction="sum" replaces the deprecated size_average=False: the loss is the
# sum (not the mean) of the element-wise squared errors.
loss_fn = torch.nn.MSELoss(reduction="sum")

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim tensor, so `.item()` is used
    # to get the Python float (`loss.data[0]` raises IndexError on
    # PyTorch >= 0.5).
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the tensors it will update (which are the learnable weights
    # of the model). Gradients accumulate across backward() calls otherwise.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 713.34765625
1 695.8994140625
2 678.9251708984375
3 662.3817138671875
4 646.3189086914062
5 630.7846069335938
6 615.7015991210938
7 600.9579467773438
8 586.571533203125
9 572.613525390625
10 559.1349487304688
11 546.0465087890625
12 533.2824096679688
13 520.9108276367188
14 508.8627624511719
15 497.10052490234375
16 485.6846618652344
17 474.5831604003906
18 463.8304138183594
19 453.35589599609375
20 443.0989990234375
21 433.1167907714844
22 423.350341796875
23 413.7862854003906
24 404.4486389160156
25 395.3368225097656
26 386.4234619140625
27 377.71575927734375
28 369.214111328125
29 360.92108154296875
30 352.79754638671875
31 344.8386535644531
32 337.09710693359375
33 329.539306640625
34 322.1540222167969
35 314.9305114746094
36 307.8539733886719
37 300.9238586425781
38 294.19158935546875
39 287.59210205078125
40 281.1143493652344
41 274.75762939453125
42 268.5172119140625
43 262.3835144042969
44 256.3536071777344
45 250.43722534179688
46 244.65541076660156
47 238.98086547851562
48 

423 0.00013494286395143718
424 0.00012981089821550995
425 0.00012486992636695504
426 0.0001201107443193905
427 0.00011552937212400138
428 0.00011111426283605397
429 0.00010686700261430815
430 0.00010277274850523099
431 9.883518214337528e-05
432 9.50380417634733e-05
433 9.13834446691908e-05
434 8.786536636762321e-05
435 8.448119478998706e-05
436 8.122088911477476e-05
437 7.808418740751222e-05
438 7.506184920202941e-05
439 7.215235382318497e-05
440 6.935308192623779e-05
441 6.665792898274958e-05
442 6.406417378457263e-05
443 6.157254392746836e-05
444 5.916814916417934e-05
445 5.6854038120945916e-05
446 5.463030902319588e-05
447 5.249132664175704e-05
448 5.043237979407422e-05
449 4.845086732530035e-05
450 4.654514850699343e-05
451 4.471241118153557e-05
452 4.2945459426846355e-05
453 4.1248677007388324e-05
454 3.96182804252021e-05
455 3.8050133298384026e-05
456 3.65393643733114e-05
457 3.508843656163663e-05
458 3.369007390574552e-05
459 3.234956238884479e-05
460 3.105778159806505e-05
461 2