In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

**The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just create the Tensor on a CUDA device, e.g. by passing
`device=torch.device("cuda:0")` as in the code below.**



In [4]:
import torch


# Element type and placement shared by every tensor allocated below.
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Problem sizes: N = batch size, D_in = input features,
# H = hidden units, D_out = output features.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and random regression targets.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Random starting weights for the two bias-free linear layers.
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU -> linear.
    h = torch.mm(x, w1)
    h_relu = torch.clamp(h, min=0)
    y_pred = torch.mm(h_relu, w2)

    # Loss is the sum of squared errors; .item() converts the 0-d tensor
    # into a plain Python number for printing.
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Backward pass: gradients of the loss with respect to w1 and w2,
    # obtained by applying the chain rule layer by layer.
    grad_y_pred = 2.0 * residual
    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)
    grad_h_relu = torch.mm(grad_y_pred, w2.t())
    # ReLU passes no gradient where its input was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)
    grad_w1 = torch.mm(x.t(), grad_h)

    # Plain gradient-descent step, applied to the weights in place.
    w1.sub_(learning_rate * grad_w1)
    w2.sub_(learning_rate * grad_w2)

0 28613416.0
1 19649098.0
2 13879484.0
3 9715483.0
4 6757820.0
5 4730265.0
6 3378470.75
7 2482085.75
8 1880921.0
9 1466421.25
10 1171524.125
11 954645.6875
12 790775.6875
13 663054.375
14 561562.375
15 479363.4375
16 411942.25
17 356016.0
18 309130.125
19 269523.53125
20 235862.453125
21 207155.546875
22 182510.3125
23 161258.140625
24 142866.875
25 126897.4609375
26 112988.171875
27 100832.7578125
28 90173.4921875
29 80808.046875
30 72544.2421875
31 65245.24609375
32 58806.3828125
33 53101.18359375
34 48029.40234375
35 43508.8515625
36 39472.24609375
37 35863.28515625
38 32625.638671875
39 29720.2109375
40 27108.8828125
41 24757.083984375
42 22636.560546875
43 20721.71875
44 18989.96484375
45 17421.544921875
46 15999.1875
47 14707.86328125
48 13534.53125
49 12466.416015625
50 11492.8154296875
51 10604.744140625
52 9793.8876953125
53 9052.4619140625
54 8373.8095703125
55 7752.1728515625
56 7182.484375
57 6659.37841796875
58 6178.7294921875
59 5736.68994140625
60 5329.79345703125
61 495

444 0.00028729334007948637
445 0.0002811662561725825
446 0.0002746893442235887
447 0.0002678956079762429
448 0.00026274833362549543
449 0.0002568237541709095
450 0.00025139909121207893
451 0.00024670493439771235
452 0.00024123169714584947
453 0.00023661907471250743
454 0.00023205665638670325
455 0.00022674095816910267
456 0.00022214400814846158
457 0.0002181239251513034
458 0.00021402929269243032
459 0.00020947889424860477
460 0.0002052340132649988
461 0.00020110694458708167
462 0.0001974245096789673
463 0.00019367666391190141
464 0.00018961662135552615
465 0.00018634159641806036
466 0.00018264632672071457
467 0.0001788699591998011
468 0.00017583543376531452
469 0.0001724396861391142
470 0.00016921339556574821
471 0.00016568797582294792
472 0.00016280767158605158
473 0.00015989478561095893
474 0.0001568773586768657
475 0.0001542574173072353
476 0.0001516686606919393
477 0.00014857275527901947
478 0.00014568038750439882
479 0.00014358243788592517
480 0.0001407347444910556
481 0.00013825