In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device (or move it to one) by setting its ``device``.



In [2]:
import torch


# Two-layer fully-connected ReLU network (no biases) trained with hand-written
# forward and backward passes on plain tensors; autograd is not used.

# Seed PyTorch's random number generator so that a fresh
# "Restart Kernel -> Run All" reproduces the same data, initial weights,
# and printed loss values.
torch.manual_seed(0)

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)              # (N, H) hidden pre-activations
    h_relu = h.clamp(min=0)   # ReLU
    y_pred = h_relu.mm(w2)    # (N, D_out) predictions

    # Compute and print loss (sum of squared errors, as a Python float).
    # The residual is reused below for the gradient, so compute it once.
    residual = y_pred - y
    loss = residual.pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * residual
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # ReLU passes gradient only where its input was non-negative; clone so
    # grad_h_relu itself is left unmodified.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 38173304.0
1 38483604.0
2 43378496.0
3 42389120.0
4 31556136.0
5 16729511.0
6 7239437.0
7 3175119.0
8 1722849.0
9 1160596.0
10 891473.5625
11 727375.25
12 610076.25
13 519243.28125
14 446070.0625
15 386015.96875
16 336088.625
17 294139.625
18 258644.328125
19 228354.578125
20 202385.03125
21 180079.828125
22 160749.546875
23 143980.421875
24 129306.8046875
25 116426.0390625
26 105069.3828125
27 95023.5703125
28 86117.8125
29 78210.578125
30 71150.53125
31 64812.42578125
32 59129.86328125
33 54025.51171875
34 49426.5
35 45274.1875
36 41520.5859375
37 38120.6796875
38 35035.98828125
39 32234.283203125
40 29684.71875
41 27362.130859375
42 25243.365234375
43 23307.2890625
44 21537.58984375
45 19917.400390625
46 18431.416015625
47 17067.71875
48 15815.24609375
49 14663.9189453125
50 13604.16796875
51 12628.0595703125
52 11728.14453125
53 10898.2109375
54 10130.89453125
55 9421.8310546875
56 8766.3115234375
57 8160.43701171875
58 7599.8232421875
59 7080.8974609375
60 6599.900390625
61 6153

413 0.000518787419423461
414 0.0005046242149546742
415 0.0004926452529616654
416 0.00048042749403975904
417 0.00046816616668365896
418 0.0004565080162137747
419 0.0004457243485376239
420 0.00043434277176856995
421 0.0004238742694724351
422 0.0004137624055147171
423 0.0004053924058098346
424 0.00039535254472866654
425 0.0003857314877677709
426 0.0003767538582906127
427 0.0003677633940242231
428 0.00035926554119214416
429 0.00035141498665325344
430 0.0003433499950915575
431 0.00033565517514944077
432 0.00032861344516277313
433 0.00032135078799910843
434 0.0003142760251648724
435 0.00030769803561270237
436 0.0003007282211910933
437 0.000294322642730549
438 0.0002877979713957757
439 0.00028261609259061515
440 0.00027666916139423847
441 0.00027011800557374954
442 0.00026487073046155274
443 0.0002591066586319357
444 0.00025366625050082803
445 0.00024867619504220784
446 0.0002442584664095193
447 0.00023962580598890781
448 0.00023519070236943662
449 0.00023008353309705853
450 0.000225841926294