### 作業目標: 使用PyTorch進行微分與倒傳遞
這份作業我們會手動實作微分與倒傳遞，並使用PyTorch的Autograd自動計算梯度。

### 使用PyTorch張量手動實作微分與倒傳遞

這裡我們簡單地實作一個兩層的神經網路來處理回歸問題，其中loss function為L2 loss（對所有樣本與輸出維度的平方誤差求和）：

$$
L2\_loss = \sum_{i,j} \left(y_{pred,ij} - y_{ij}\right)^2
$$

兩層神經網路如下所示：
$$
y_{pred} = \mathrm{ReLU}(XW_1)W_2
$$

In [1]:
import torch
# Every tensor in this notebook is created on the CPU;
# swap in torch.device("cuda") here to run on a GPU instead.
device = torch.device("cpu")

In [2]:
# Problem sizes:
#   N     -- number of samples in the single, full batch
#   D_in  -- input feature dimension
#   H     -- hidden-layer width
#   D_out -- output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

# Fix the RNG so the data, the initial weights and therefore the printed
# losses are reproducible. The four randn calls below must keep this order.
torch.manual_seed(1234)

# Random regression data: inputs x and targets y.
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Layer weights, drawn from an unscaled standard normal.
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

# Step size for plain gradient descent.
learning_rate = 1e-6

# 500 full-batch gradient-descent steps with hand-derived gradients.
for t in range(500):
    # Forward pass: y_pred = ReLU(x @ w1) @ w2
    h = x @ w1                  # (N, D_in) @ (D_in, H)  -> (N, H)
    h_relu = h.clamp(min=0)     # ReLU
    y_pred = h_relu @ w2        # (N, H) @ (H, D_out)    -> (N, D_out)

    # Sum-of-squares (L2) loss over every element.
    residual = y_pred - y
    loss = residual.pow(2).sum()
    print(t, loss.item())

    # Backward pass: apply the chain rule from the loss back to each weight.
    grad_y_pred = 2.0 * residual                # dloss/dy_pred
    grad_w2 = h_relu.T @ grad_y_pred            # dloss/dw2
    grad_h_relu = grad_y_pred @ w2.T            # dloss/dh_relu
    # ReLU blocks the gradient wherever its input was negative.
    grad_h = grad_h_relu.masked_fill(h < 0, 0)  # dloss/dh
    grad_w1 = x.T @ grad_h                      # dloss/dw1

    # Gradient-descent step, applied in place to the weight tensors.
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 24151626.0
1 17373814.0
2 14810979.0
3 13811587.0
4 13237697.0
5 12380204.0
6 11037116.0
7 9215422.0
8 7263422.5
9 5425220.0
10 3929178.5
11 2792755.0
12 1986375.5
13 1426316.75
14 1045155.0
15 783794.75
16 603579.0
17 476743.125
18 385634.53125
19 318417.0625
20 267663.3125
21 228307.625
22 197113.375
23 171859.65625
24 151016.75
25 133559.875
26 118731.2265625
27 106010.7890625
28 95021.5390625
29 85452.75
30 77058.859375
31 69648.5625
32 63078.078125
33 57228.92578125
34 52008.8671875
35 47334.9453125
36 43141.6171875
37 39370.1484375
38 35973.5859375
39 32910.20703125
40 30138.205078125
41 27625.6953125
42 25346.88671875
43 23277.66796875
44 21396.4921875
45 19684.412109375
46 18123.9140625
47 16703.755859375
48 15411.8916015625
49 14232.470703125
50 13154.0234375
51 12164.80078125
52 11257.572265625
53 10424.779296875
54 9659.8408203125
55 8956.8740234375
56 8309.5546875
57 7713.3134765625
58 7164.28125
59 6657.83447265625
60 6190.37744140625
61 5758.98486328125
62 5360.67285156

389 0.000979539705440402
390 0.0009501409367658198
391 0.0009216705802828074
392 0.0008944420260377228
393 0.0008695028955116868
394 0.0008428359287790954
395 0.0008178879506886005
396 0.000794422288890928
397 0.0007711535436101258
398 0.0007484162342734635
399 0.0007286592153832316
400 0.0007092966698110104
401 0.0006881424924358726
402 0.0006696760538034141
403 0.0006513833068311214
404 0.0006336182705126703
405 0.0006168068503029644
406 0.000600655737798661
407 0.000584747816901654
408 0.000569107651244849
409 0.0005539483390748501
410 0.0005392478778958321
411 0.0005256481817923486
412 0.0005124998278915882
413 0.0005001454846933484
414 0.00048697550664655864
415 0.0004747537604998797
416 0.0004642282729037106
417 0.0004530047590378672
418 0.0004412662819959223
419 0.0004307531926315278
420 0.00041952868923544884
421 0.0004098778299521655
422 0.0004000063636340201
423 0.00038985832361504436
424 0.0003805737942457199
425 0.0003723420377355069
426 0.00036328122951090336
427 0.0003549

### 使用PyTorch的Autograd

In [3]:
import torch
# Every tensor in this notebook is created on the CPU;
# swap in torch.device("cuda") here to run on a GPU instead.
device = torch.device("cpu")

In [4]:
# Problem sizes: batch size N, input dim D_in, hidden width H, output dim D_out.
N, D_in, H, D_out = 64, 1000, 100, 10

# Fixed seed so the data and initial weights are reproducible.
torch.manual_seed(1234)

# Random regression data (inputs x, targets y); these need no gradients.
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# requires_grad=True makes autograd record operations on the weights,
# so loss.backward() can populate w1.grad and w2.grad.
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

# Step size for plain gradient descent.
learning_rate = 1e-6

# 500 full-batch gradient-descent steps; autograd supplies the gradients.
for t in range(500):
    # Forward pass ReLU(x @ w1) @ w2 as one expression: autograd keeps the
    # intermediates it needs, so none are stored by hand.
    y_pred = (x @ w1).clamp(min=0) @ w2

    # Sum-of-squares (L2) loss over every element.
    loss = ((y_pred - y) ** 2).sum()
    print(t, loss.item())

    # Backward pass: accumulates dloss/dw1 and dloss/dw2 into .grad.
    loss.backward()

    # Step outside autograd tracking so the in-place update is not recorded,
    # then clear .grad because backward() accumulates into it.
    with torch.no_grad():
        for weight in (w1, w2):
            weight -= learning_rate * weight.grad
            weight.grad.zero_()

0 24151626.0
1 17373814.0
2 14810979.0
3 13811587.0
4 13237697.0
5 12380204.0
6 11037116.0
7 9215422.0
8 7263422.5
9 5425220.0
10 3929178.5
11 2792755.0
12 1986375.5
13 1426316.75
14 1045155.0
15 783794.75
16 603579.0
17 476743.125
18 385634.53125
19 318417.0625
20 267663.3125
21 228307.625
22 197113.375
23 171859.65625
24 151016.75
25 133559.875
26 118731.2265625
27 106010.7890625
28 95021.5390625
29 85452.75
30 77058.859375
31 69648.5625
32 63078.078125
33 57228.92578125
34 52008.8671875
35 47334.9453125
36 43141.6171875
37 39370.1484375
38 35973.5859375
39 32910.20703125
40 30138.205078125
41 27625.6953125
42 25346.88671875
43 23277.66796875
44 21396.4921875
45 19684.412109375
46 18123.9140625
47 16703.755859375
48 15411.8916015625
49 14232.470703125
50 13154.0234375
51 12164.80078125
52 11257.572265625
53 10424.779296875
54 9659.8408203125
55 8956.8740234375
56 8309.5546875
57 7713.3134765625
58 7164.28125
59 6657.83447265625
60 6190.37744140625
61 5758.98486328125
62 5360.67285156

410 0.0005392478778958321
411 0.0005256481817923486
412 0.0005124998278915882
413 0.0005001454846933484
414 0.00048697550664655864
415 0.0004747537604998797
416 0.0004642282729037106
417 0.0004530047590378672
418 0.0004412662819959223
419 0.0004307531926315278
420 0.00041952868923544884
421 0.0004098778299521655
422 0.0004000063636340201
423 0.00038985832361504436
424 0.0003805737942457199
425 0.0003723420377355069
426 0.00036328122951090336
427 0.0003549797984305769
428 0.0003470449591986835
429 0.0003386019088793546
430 0.0003305350837763399
431 0.0003235472831875086
432 0.00031560868956148624
433 0.00030900444835424423
434 0.0003025565529242158
435 0.00029531086329370737
436 0.00028917231247760355
437 0.00028322706930339336
438 0.0002773759188130498
439 0.0002713072462938726
440 0.0002656866854522377
441 0.0002600733714643866
442 0.0002549859054852277
443 0.0002493115025572479
444 0.0002442499389871955
445 0.00023879083164501935
446 0.00023410988796968013
447 0.00022976042237132788
