In [1]:
%matplotlib inline


PyTorch: Tensors
----------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.  
这个实现使用 PyTorch 的张量来手动计算前向传播、损失以及反向传播。

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.  
numpy 数组和 PyTorch 张量之间的最大不同是 PyTorch 张量能够在CPU和GPU上运行。为了让其在GPU上运行，只需要把张量转换成 cuda 类型的数据。



In [4]:
import torch
import time


# Train a two-layer fully-connected ReLU network (no biases) with hand-written
# forward and backward passes on plain tensors, and time the training loop.

# Seed the RNG so that Restart & Run All reproduces the same data, initial
# weights and loss curve. All random tensors below are drawn on the CPU
# before being cast, so seeding the default generator is sufficient.
torch.manual_seed(0)

# Use the GPU when CUDA is usable and fall back to the CPU otherwise, so the
# cell also runs on machines without a GPU. `dtype` stays a legacy tensor
# type: `.type(dtype)` both casts the tensor and places it on that device.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6

# CUDA kernels are launched asynchronously: drain pending work before reading
# the clock so that only the training loop is measured.
if use_cuda:
    torch.cuda.synchronize()
# perf_counter() is monotonic and high-resolution, which makes it the right
# clock for measuring an elapsed interval.
time_begin = time.perf_counter()

for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)  # ReLU
    y_pred = h_relu.mm(w2)

    # Compute and print loss (sum of squared errors).
    # float() yields a plain Python number whether .sum() returns a Python
    # scalar or a 0-dim tensor, so the printed value is never a tensor repr.
    loss = (y_pred - y).pow(2).sum()
    print(t, float(loss))

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    # ReLU passes gradient only where its input was non-negative; clone first
    # so that grad_h_relu itself is not modified by the masking.
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

# Wait for any queued GPU work to finish before stopping the clock.
if use_cuda:
    torch.cuda.synchronize()
time_end = time.perf_counter()
print(time_end - time_begin)

0 34047534.01889758
1 35599780.620915055
2 43260161.783258915
3 47121863.77586162
4 39585535.2887817
5 22736705.609122515
6 9805321.547041506
7 3957172.938485967
8 1983917.8761560265
9 1291805.8497694135
10 986200.3991556331
11 806225.5738645177
12 678216.7914754432
13 578435.0393936131
14 497297.216810961
15 430186.5057797349
16 374027.79026160017
17 326607.48654790944
18 286295.5938097512
19 251756.8847007178
20 222110.23693604209
21 196586.08992756082
22 174498.93558853038
23 155277.6075473379
24 138503.8416155039
25 123825.54191570605
26 110928.96115449816
27 99567.0299982719
28 89533.95665663492
29 80658.58668093581
30 72793.63124886528
31 65803.67803303571
32 59580.221640136326
33 54032.33214779396
34 49077.357100771085
35 44633.37268998643
36 40643.71822310857
37 37057.36137744257
38 33828.034584160894
39 30907.95714095764
40 28271.36220862844
41 25887.58058714599
42 23728.949903042885
43 21773.656664871472
44 19996.937812568787
45 18382.299514061597
46 16913.255761062945
47 155

467 7.732178824598757e-05
468 7.653485196185486e-05
469 7.520528624349598e-05
470 7.402346920453232e-05
471 7.283289960744246e-05
472 7.171484421619745e-05
473 7.059591574741508e-05
474 6.930171079471206e-05
475 6.839036592347358e-05
476 6.734260681344473e-05
477 6.636604497218049e-05
478 6.522095178093656e-05
479 6.407732164769031e-05
480 6.320740504682856e-05
481 6.228864587036476e-05
482 6.134841254071266e-05
483 6.055338404634472e-05
484 5.9643823104901794e-05
485 5.88304768895237e-05
486 5.8126769043306825e-05
487 5.718578293452403e-05
488 5.630236546671011e-05
489 5.544741151553012e-05
490 5.4871072831228634e-05
491 5.403721306315723e-05
492 5.3366757183564983e-05
493 5.2790806637309506e-05
494 5.22314235321987e-05
495 5.1297501885011476e-05
496 5.074497455357166e-05
497 4.989760323308168e-05
498 4.9397276723456285e-05
499 4.861034117767393e-05
0.6623139381408691
