In [1]:
%matplotlib inline


PyTorch: Tensors and autograd
-------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.


A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True`` then, after a backward pass,
``x.grad`` is another Tensor holding the gradient of some scalar value with
respect to ``x``.



In [2]:
import torch

# Seed the random number generator so that the random data, the initial weights
# and therefore the printed loss curve are the same on every fresh run.
torch.manual_seed(0)

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
# requires_grad is left at its default (False), which indicates that we do not
# need to compute gradients with respect to these Tensors during the backward
# pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear layer -> ReLU (clamp at 0) -> linear layer, written
    # as plain Tensor operations. We do not need to keep references to the
    # intermediate values since we are not implementing the backward pass by
    # hand; autograd records what it needs.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print loss using operations on Tensors.
    # loss is a 0-dimensional (scalar) Tensor; loss.item() returns the value
    # it holds as a plain Python number.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights; backward()
        # accumulates into .grad, so stale values would otherwise be added to
        # the next iteration's gradients.
        w1.grad.zero_()
        w2.grad.zero_()

0 27565776.0
1 23296002.0
2 22137930.0
3 21047470.0
4 18665428.0
5 14753441.0
6 10461987.0
7 6796198.5
8 4264680.5
9 2691870.0
10 1773938.25
11 1238612.375
12 918028.4375
13 715358.875
14 578758.5
15 480871.75
16 406825.9375
17 348511.53125
18 301212.6875
19 262094.15625
20 229176.984375
21 201223.328125
22 177303.046875
23 156741.140625
24 138948.53125
25 123507.21875
26 110051.421875
27 98274.984375
28 87934.359375
29 78827.5625
30 70792.3359375
31 63684.03515625
32 57383.453125
33 51786.38671875
34 46803.51953125
35 42361.15234375
36 38390.03125
37 34835.1640625
38 31649.2578125
39 28788.23828125
40 26213.994140625
41 23895.884765625
42 21805.267578125
43 19919.162109375
44 18216.5859375
45 16674.19140625
46 15275.4716796875
47 14006.345703125
48 12853.3876953125
49 11804.59765625
50 10849.669921875
51 9979.0849609375
52 9184.4453125
53 8458.697265625
54 7795.19580078125
55 7188.6767578125
56 6633.47119140625
57 6124.61669921875
58 5657.951171875
59 5229.66357421875
60 4836.40185546

396 9.923791367327794e-05
397 9.72749330685474e-05
398 9.501873864792287e-05
399 9.271435556001961e-05
400 9.093934204429388e-05
401 8.910265751183033e-05
402 8.761539356783032e-05
403 8.577495464123785e-05
404 8.393446478294209e-05
405 8.245249046012759e-05
406 8.099155820673332e-05
407 7.915400783531368e-05
408 7.782346801832318e-05
409 7.633116911165416e-05
410 7.483728404622525e-05
411 7.313604874070734e-05
412 7.18058945494704e-05
413 7.047194230835885e-05
414 6.951852992642671e-05
415 6.83324396959506e-05
416 6.696346827084199e-05
417 6.566964293597266e-05
418 6.451181252487004e-05
419 6.359924009302631e-05
420 6.270920130191371e-05
421 6.18329577264376e-05
422 6.0633523389697075e-05
423 5.9861431509489194e-05
424 5.869346932740882e-05
425 5.7455359637970105e-05
426 5.664665513904765e-05
427 5.571632573264651e-05
428 5.492466880241409e-05
429 5.407247590483166e-05
430 5.305188460624777e-05
431 5.2117604354862124e-05
432 5.116728789289482e-05
433 5.0633021601242945e-05
434 4.97593