In [None]:
%matplotlib inline


PyTorch: Tensors
----------------

这次我们使用PyTorch tensors来创建前向神经网络，计算损失，以及反向传播。

一个PyTorch Tensor很像一个numpy的ndarray。但是它和numpy ndarray最大的区别是，PyTorch Tensor可以在CPU或者GPU上运算。如果想要在GPU上运算，就需要把Tensor换成cuda类型。


In [1]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 26888820.0
1 20979944.0
2 19380022.0
3 18957662.0
4 18105740.0
5 15903030.0
6 12617097.0
7 9022224.0
8 5997497.5
9 3812973.25
10 2415032.75
11 1563653.5
12 1057986.75
13 753645.0
14 565461.0
15 443420.46875
16 360100.625
17 300133.21875
18 254813.6875
19 219264.0625
20 190485.3125
21 166645.234375
22 146664.4375
23 129661.5859375
24 115049.53125
25 102405.5390625
26 91405.4921875
27 81788.28125
28 73350.3125
29 65918.3203125
30 59366.5
31 53569.9765625
32 48424.328125
33 43854.46875
34 39777.37109375
35 36129.16015625
36 32867.421875
37 29941.541015625
38 27312.00390625
39 24945.33203125
40 22810.458984375
41 20885.125
42 19142.91015625
43 17564.587890625
44 16134.7822265625
45 14835.9150390625
46 13653.8330078125
47 12577.2900390625
48 11594.9423828125
49 10698.267578125
50 9878.6494140625
51 9128.45703125
52 8441.27734375
53 7811.24169921875
54 7233.18408203125
55 6702.48388671875
56 6214.18603515625
57 5764.9599609375
58 5351.634765625
59 4970.8505859375
60 4619.44287109375
61 429

434 9.584150393493474e-05
435 9.391656203661114e-05
436 9.165167284663767e-05
437 8.984629675978795e-05
438 8.805704419501126e-05
439 8.642551256343722e-05
440 8.451484609395266e-05
441 8.300202171085402e-05
442 8.138606062857434e-05
443 7.980818918440491e-05
444 7.837548764655367e-05
445 7.705082680331543e-05
446 7.567813736386597e-05
447 7.401522452710196e-05
448 7.298204582184553e-05
449 7.152874604798853e-05
450 7.034096779534593e-05
451 6.91091627231799e-05
452 6.780566764064133e-05
453 6.648786802543327e-05
454 6.545426731463522e-05
455 6.427622429328039e-05
456 6.304256385192275e-05
457 6.205030513228849e-05
458 6.1020426073810086e-05
459 5.995111496304162e-05
460 5.881515608052723e-05
461 5.7683875638758764e-05
462 5.670229802490212e-05
463 5.538931873161346e-05
464 5.449245873023756e-05
465 5.3626226872438565e-05
466 5.276917727314867e-05
467 5.19685163453687e-05
468 5.1210750825703144e-05
469 5.04882846144028e-05
470 4.9606755055719987e-05
471 4.897603139397688e-05
472 4.7940