In [None]:
%matplotlib inline


PyTorch: Tensors and autograd
-------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.


A PyTorch Tensor represents a node in a computational graph. If ``x`` is a
Tensor that has ``x.requires_grad=True``, then after calling ``backward()`` on
some scalar value, ``x.grad`` is another Tensor holding the gradient of that
scalar with respect to ``x``.



In [4]:
import torch

dtype = torch.float
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# this cell still runs on machines without a GPU instead of raising.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize the weights. requires_grad=True tells autograd to record
# operations on them so that backward() can populate w1.grad and w2.grad.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU (clamp at 0) -> linear, no biases.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute and print the loss (sum of squared errors). loss is a
    # one-element tensor; .item() extracts the Python number it holds.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Backprop: compute the gradient of the loss w.r.t. w1 and w2.
    loss.backward()

    # Gradient-descent step. Wrapped in no_grad() because the in-place update
    # of the weights must not itself be tracked by autograd.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # backward() accumulates into .grad, so reset the gradients manually
        # before the next iteration.
        w1.grad.zero_()
        w2.grad.zero_()

0 32168110.0
1 29678516.0
2 30679060.0
3 30371336.0
4 25980168.0
5 18357526.0
6 10978420.0
7 5986260.0
8 3303671.0
9 1984930.75
10 1336795.75
11 992613.5
12 787797.0625
13 650478.5
14 549386.125
15 470292.4375
16 406132.5625
17 352911.6875
18 308270.34375
19 270412.65625
20 238064.125
21 210312.4375
22 186358.453125
23 165591.109375
24 147511.5625
25 131720.71875
26 117881.78125
27 105721.84375
28 95005.265625
29 85532.6953125
30 77142.578125
31 69698.2265625
32 63071.5625
33 57161.77734375
34 51880.32421875
35 47151.5703125
36 42909.28125
37 39100.703125
38 35682.796875
39 32603.427734375
40 29823.548828125
41 27309.236328125
42 25032.517578125
43 22968.25390625
44 21094.0546875
45 19390.359375
46 17839.359375
47 16426.267578125
48 15137.72265625
49 13961.349609375
50 12888.244140625
51 11907.923828125
52 11010.5517578125
53 10187.2529296875
54 9431.78125
55 8738.27734375
56 8101.1943359375
57 7515.47705078125
58 6976.2783203125
59 6481.3935546875
60 6024.763671875
61 5603.33154296875

390 0.0005168694187887013
391 0.000502896960824728
392 0.0004885565722361207
393 0.0004743315512314439
394 0.00046067938092164695
395 0.0004478106275200844
396 0.00043537476449273527
397 0.0004220565897412598
398 0.00041053033783100545
399 0.00039936782559379935
400 0.00038853613659739494
401 0.000378349213860929
402 0.0003683097893372178
403 0.00035901612136512995
404 0.000348952307831496
405 0.0003397812251932919
406 0.0003307886072434485
407 0.00032176237436942756
408 0.00031399534782394767
409 0.0003053421387448907
410 0.0002978391421493143
411 0.00029077904764562845
412 0.00028318766271695495
413 0.0002752986620180309
414 0.00026860018260776997
415 0.0002623535692691803
416 0.00025633457698859274
417 0.00025006805662997067
418 0.00024417060194537044
419 0.0002378867648076266
420 0.00023158182739280164
421 0.00022659519163426012
422 0.00022158975480124354
423 0.00021639597252942622
424 0.00021122326143085957
425 0.0002063649008050561
426 0.00020130339544266462
427 0.000196602952200