In [2]:
# IPython magic: render matplotlib figures inline in the notebook output.
# No cell visible in this part of the notebook draws a plot, so this only
# matters if other cells do.
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

In this example a PyTorch Tensor is used in basically the same way as a numpy
array: as a generic n-dimensional array for arbitrary numeric computation.
Tensors can also track computational graphs and gradients (autograd), but that
feature is not used here; the gradients are computed by hand.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device (or move it there), for example by passing
``device=torch.device("cuda:0")`` when constructing it.



In [3]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed the random number generator before drawing any data so that the inputs,
# targets and initial weights -- and therefore the printed losses -- are the
# same every time the notebook is restarted and re-run.
seed = 0
torch.manual_seed(seed)

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)                # hidden pre-activations, shape (N, H)
    h_relu = h.clamp(min=0)     # ReLU: negative entries become 0
    y_pred = h_relu.mm(w2)      # predictions, shape (N, D_out)

    # Compute and print loss (sum of squared errors over the whole batch).
    # .item() converts the 0-dim result tensor into a plain Python float.
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss.
    # Each step applies the chain rule to the forward pass in reverse order.
    grad_y_pred = 2.0 * (y_pred - y)        # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)    # d(loss)/d(w2)
    grad_h_relu = grad_y_pred.mm(w2.t())    # d(loss)/d(h_relu)
    # ReLU passes no gradient where its input was negative. Work on a copy so
    # grad_h_relu itself is left unmodified.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)              # d(loss)/d(w1)

    # Update weights using gradient descent (in place)
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 34105728.0
1 31627340.0
2 33199556.0
3 32789760.0
4 27412030.0
5 18509252.0
6 10525968.0
7 5551514.0
8 3069181.5
9 1908437.875
10 1345917.375
11 1041337.25
12 852534.1875
13 720376.3125
14 619650.125
15 538774.875
16 471927.0625
17 415744.53125
18 367936.46875
19 326954.4375
20 291583.9375
21 260888.46875
22 234088.484375
23 210603.796875
24 189947.203125
25 171702.875
26 155553.75
27 141221.0
28 128449.9609375
29 117076.0
30 106894.3515625
31 97745.4140625
32 89513.6328125
33 82084.0078125
34 75367.46875
35 69281.8125
36 63760.05859375
37 58766.234375
38 54219.0390625
39 50073.80859375
40 46288.2890625
41 42825.58984375
42 39655.2421875
43 36752.12890625
44 34086.9453125
45 31637.322265625
46 29384.787109375
47 27312.388671875
48 25403.580078125
49 23641.8984375
50 22014.943359375
51 20511.599609375
52 19121.033203125
53 17833.708984375
54 16641.0234375
55 15536.01953125
56 14511.43359375
57 13560.015625
58 12676.267578125
59 11854.85546875
60 11091.3486328125
61 10381.8740234375
62