In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is conceptually the same as a numpy array: a generic
n-dimensional array to be used for arbitrary numeric computation. Tensors can
also keep track of a computational graph and gradients, but this example does
not use that feature; all gradients below are computed by hand.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device (or move it there), e.g. by passing
``device=torch.device("cuda:0")`` when constructing it.



In [3]:
import torch

# Train a two-layer fully-connected ReLU network (no biases) with hand-written
# forward and backward passes on plain tensors; autograd is not used.

dtype = torch.float
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# this cell still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNG so the random data and initial weights (and therefore the
# printed losses) are repeatable on a given device after a kernel restart.
torch.manual_seed(0)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.

N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)                # (N, H) pre-activation
    h_relu = h.clamp(min=0)     # ReLU
    y_pred = h_relu.mm(w2)      # (N, D_out)

    # Compute and print loss (sum of squared errors); .item() extracts the
    # scalar tensor's value as a Python float.
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of the loss w.r.t. w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # ReLU passes gradient only where its input was non-negative; clone first
    # so grad_h_relu itself is not modified by the masking below.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights in place with plain gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 22986976.0
1 18275036.0
2 16870226.0
3 16666508.0
4 16265776.0
5 14959135.0
6 12524398.0
7 9563903.0
8 6711250.0
9 4483265.0
10 2921371.5
11 1921702.0
12 1299297.75
13 917525.1875
14 678915.75
15 525509.125
16 422336.40625
17 349654.0
18 295939.9375
19 254529.765625
20 221461.40625
21 194309.296875
22 171561.75
23 152251.203125
24 135648.21875
25 121247.0625
26 108683.015625
27 97657.8046875
28 87953.5546875
29 79390.46875
30 71815.5
31 65089.375
32 59094.3671875
33 53734.078125
34 48934.39453125
35 44630.234375
36 40760.234375
37 37271.890625
38 34122.984375
39 31276.166015625
40 28699.8125
41 26356.67578125
42 24229.48828125
43 22296.62890625
44 20537.19921875
45 18934.671875
46 17472.419921875
47 16136.0703125
48 14914.0439453125
49 13795.4755859375
50 12771.984375
51 11835.0966796875
52 10975.107421875
53 10184.3515625
54 9456.9873046875
55 8782.875
56 8162.0185546875
57 7590.12109375
58 7062.4970703125
59 6575.302734375
60 6125.1494140625
61 5708.9892578125
62 5323.916015625
63 

384 0.0008886042051017284
385 0.0008606893243268132
386 0.0008314838050864637
387 0.0008040011744014919
388 0.0007777330465614796
389 0.0007547168061137199
390 0.0007307775085791945
391 0.000708183622919023
392 0.0006847763434052467
393 0.0006642260123044252
394 0.0006440814468078315
395 0.0006233673775568604
396 0.0006059232400730252
397 0.0005876889917999506
398 0.000570912379771471
399 0.0005530699854716659
400 0.0005366089171729982
401 0.0005207810318097472
402 0.0005055778892710805
403 0.0004910661955364048
404 0.0004771137610077858
405 0.00046402827138081193
406 0.0004510499711614102
407 0.0004375100543256849
408 0.00042629108065739274
409 0.0004152172477915883
410 0.0004036572645418346
411 0.0003924361080862582
412 0.0003820806450676173
413 0.00037174453609623015
414 0.0003618387854658067
415 0.0003517276491038501
416 0.0003426604380365461
417 0.00033453290234319866
418 0.00032606814056634903
419 0.00031805099570192397
420 0.00030942680314183235
421 0.0003021321026608348
422 0.0