Same as before, but now with PyTorch tensors.
The biggest difference between a numpy array and a PyTorch Tensor, is that a PyTorch Tensor can run on either CPU or GPU.

In [1]:
import torch

In [2]:
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

In [4]:
# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

In [5]:
# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

In [6]:
learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 49237888.0
1 57651568.0
2 64022900.0
3 49465616.0
4 23452348.0
5 7786769.0
6 3143276.25
7 1937381.25
8 1479062.75
9 1204780.75
10 1002841.0
11 844103.125
12 716419.875
13 612394.125
14 526802.375
15 455701.125
16 396288.84375
17 346273.0
18 303848.4375
19 267657.75
20 236626.375
21 209919.65625
22 186851.03125
23 166802.390625
24 149299.078125
25 133941.5
26 120425.953125
27 108491.046875
28 97935.015625
29 88556.953125
30 80210.6015625
31 72772.6953125
32 66120.1015625
33 60158.05859375
34 54803.390625
35 49984.73046875
36 45643.484375
37 41724.328125
38 38177.76953125
39 34967.33984375
40 32057.76953125
41 29419.369140625
42 27021.453125
43 24840.07421875
44 22852.5625
45 21040.1796875
46 19386.25
47 17874.103515625
48 16490.630859375
49 15224.4951171875
50 14066.0
51 13005.970703125
52 12033.58203125
53 11139.759765625
54 10319.3388671875
55 9564.4052734375
56 8869.0615234375
57 8228.564453125
58 7637.8662109375
59 7092.841796875
60 6589.7119140625
61 6124.9638671875
62 5695.52441