In [3]:
import torch

In [4]:


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 29560736.0
1 24667754.0
2 22743136.0
3 20562466.0
4 16967238.0
5 12517790.0
6 8359041.5
7 5286023.0
8 3312601.75
9 2146933.75
10 1471030.875
11 1073522.75
12 827642.625
13 665713.375
14 551708.0
15 466687.46875
16 400343.625
17 346841.15625
18 302733.21875
19 265745.21875
20 234292.171875
21 207436.71875
22 184350.90625
23 164388.546875
24 147032.609375
25 131886.796875
26 118614.21875
27 106968.515625
28 96693.640625
29 87601.171875
30 79527.9140625
31 72335.921875
32 65915.890625
33 60162.33203125
34 54998.85546875
35 50350.7109375
36 46157.51953125
37 42366.703125
38 38935.390625
39 35823.31640625
40 32996.0234375
41 30422.6640625
42 28078.466796875
43 25939.669921875
44 23985.46484375
45 22197.59765625
46 20559.615234375
47 19057.208984375
48 17677.73828125
49 16409.775390625
50 15243.48828125
51 14169.4013671875
52 13179.0517578125
53 12265.3935546875
54 11421.5966796875
55 10642.1455078125
56 9920.880859375
57 9254.5322265625
58 8638.0810546875
59 8066.8837890625
60 7537.109375

442 0.00022910477127879858
443 0.0002238870511064306
444 0.00021917522826697677
445 0.00021337081852834672
446 0.0002090591297019273
447 0.00020482145191635936
448 0.0002001166285481304
449 0.00019586950656957924
450 0.00019079685444012284
451 0.000186906720045954
452 0.0001826600928325206
453 0.00017923710402101278
454 0.00017570493218954653
455 0.00017184860189445317
456 0.00016791463713161647
457 0.00016438467719126493
458 0.00016139411309268326
459 0.00015760670066811144
460 0.00015454774256795645
461 0.00015104797785170376
462 0.00014827519771642983
463 0.0001451713324058801
464 0.0001421858323737979
465 0.00013947131810709834
466 0.00013655280054081231
467 0.00013374285481404513
468 0.0001314034452661872
469 0.0001289334031753242
470 0.00012644701928365976
471 0.00012397700629662722
472 0.00012162506027380005
473 0.00011936147348023951
474 0.0001172802658402361
475 0.00011525127774802968
476 0.00011303213250357658
477 0.000111094988824334
478 0.00010853115963982418
479 0.00010679