In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
create the Tensor on a CUDA device (or move it there), for example by passing
`device=torch.device("cuda:0")` when constructing it.



In [2]:
import torch


# Manually trained two-layer ReLU network: y_pred = relu(x @ w1) @ w2.
# Forward pass, loss and gradients are all written out by hand on plain
# tensors; autograd is not used.

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Seed the random number generator so that the random data, the initial
# weights, and therefore the printed loss curve are the same on every
# Restart Kernel -> Run All.
torch.manual_seed(0)

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)                # hidden pre-activation, (N, H)
    h_relu = h.clamp(min=0)     # ReLU
    y_pred = h_relu.mm(w2)      # prediction, (N, D_out)

    # Compute and print loss: sum of squared errors, pulled out of the
    # tensor as a plain Python float with .item()
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)          # d(loss)/d(y_pred)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # Clone before masking so grad_h_relu itself is not modified in place.
    grad_h = grad_h_relu.clone()
    # ReLU passes no gradient where its input was negative.
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 46134944.0
1 57460364.0
2 79499288.0
3 82400768.0
4 47455252.0
5 13316391.0
6 3532003.75
7 1936640.75
8 1486554.5
9 1215765.875
10 1011070.3125
11 849523.6875
12 719553.8125
13 613963.0625
14 527411.8125
15 455556.375
16 395657.5625
17 345363.15625
18 302815.84375
19 266607.65625
20 235631.53125
21 208977.6875
22 185929.671875
23 165893.203125
24 148412.578125
25 133116.703125
26 119680.4453125
27 107837.8828125
28 97372.5546875
29 88092.96875
30 79819.5390625
31 72447.8515625
32 65863.78125
33 59967.94921875
34 54672.49609375
35 49921.56640625
36 45643.0625
37 41780.83984375
38 38288.11328125
39 35121.23828125
40 32251.611328125
41 29644.619140625
42 27277.041015625
43 25120.974609375
44 23159.6015625
45 21368.3203125
46 19731.291015625
47 18233.240234375
48 16861.095703125
49 15603.6728515625
50 14449.1904296875
51 13389.19921875
52 12414.0263671875
53 11516.07421875
54 10688.1845703125
55 9925.158203125
56 9221.197265625
57 8571.1162109375
58 7970.8037109375
59 7416.07373046875
60

379 0.0006691119051538408
380 0.0006492964457720518
381 0.0006278865039348602
382 0.0006077482248656452
383 0.0005886572762392461
384 0.0005716768791899085
385 0.0005563989980146289
386 0.0005396514316089451
387 0.0005232192925177515
388 0.0005074329674243927
389 0.0004926634137518704
390 0.0004800721362698823
391 0.00046602587099187076
392 0.0004530635487753898
393 0.0004402756749186665
394 0.0004286105395294726
395 0.00041670826612971723
396 0.00040480101597495377
397 0.00039359580841846764
398 0.000383266102289781
399 0.00037291436456143856
400 0.0003637511981651187
401 0.00035427979310043156
402 0.000344447122188285
403 0.000334419310092926
404 0.0003267258871346712
405 0.0003183324879501015
406 0.00030969135696068406
407 0.000302841974189505
408 0.0002950388006865978
409 0.0002879663370549679
410 0.00028056485462002456
411 0.00027359212981536984
412 0.0002662852348294109
413 0.0002599635918159038
414 0.000253837468335405
415 0.000247338495682925
416 0.00024181186745408922
417 0.00