使用NumPy，手动实现网络的前向和反向传播，来拟合随机数据

In [9]:
import numpy as np

# Network dimensions: batch size, input features, hidden units, output features.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random inputs and random regression targets that the network will fit.
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialised weights of the two fully connected layers (no biases).
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU -> linear, producing the predictions.
    h = x @ w1
    h_relu = np.maximum(h, 0)
    y_pred = h_relu @ w2

    # Sum-of-squares loss over the whole batch, logged on every iteration.
    loss = ((y_pred - y) ** 2).sum()
    print(t, loss)

    # Backward pass: gradients of the loss with respect to w2 and w1.
    grad_y_pred = 2.0 * (y_pred - y)  # d(loss)/d(y_pred)
    grad_w2 = h_relu.T @ grad_y_pred
    grad_h_relu = grad_y_pred @ w2.T
    # ReLU backward: zero the gradient wherever the pre-activation was negative
    # (entries with h == 0 let the gradient through unchanged).
    grad_h = np.where(h < 0, 0.0, grad_h_relu)
    grad_w1 = x.T @ grad_h

    # Plain gradient-descent step, applied in place. Both gradients were
    # computed above from the pre-update weights, so the update order is free.
    w2 -= learning_rate * grad_w2
    w1 -= learning_rate * grad_w1

0 27201978.113615196
1 21170252.74999618
2 17277485.03405474
3 13806464.23950237
4 10495373.352864
5 7608491.65449019
6 5332744.660555856
7 3702563.446750896
8 2593573.642414475
9 1863618.0987471223
10 1382660.8246212974
11 1062039.4518933836
12 841936.9068820556
13 686002.3067041485
14 571479.2609183869
15 484452.8953394148
16 416095.8343561197
17 361062.65563149453
18 315805.3704486707
19 278000.5767565679
20 246001.01210356908
21 218645.61877446546
22 195050.74535285903
23 174584.75656639657
24 156714.76402617816
25 141033.07576563
26 127216.15201089438
27 115005.31868909133
28 104169.3111894581
29 94524.2247074908
30 85924.9325417742
31 78232.77738934135
32 71331.83616037625
33 65128.34782979985
34 59538.53166611515
35 54493.35515485745
36 49933.84672799064
37 45805.038932310854
38 42060.99842263029
39 38663.75653995537
40 35572.19439236564
41 32757.00383631389
42 30190.42959374978
43 27850.100401451084
44 25711.956948278726
45 23755.822699034663
46 21964.536045313205
47 20322.4277

420 7.172896644019545e-05
421 6.844705698407037e-05
422 6.531627902959549e-05
423 6.232924742088541e-05
424 5.9478490464762413e-05
425 5.675977694780438e-05
426 5.416593380725865e-05
427 5.168955227568112e-05
428 4.932727575616281e-05
429 4.7074646189496535e-05
430 4.4923894851451826e-05
431 4.287120797450241e-05
432 4.091270767250888e-05
433 3.904536300726166e-05
434 3.726196162762611e-05
435 3.5560290722143873e-05
436 3.3936994741341146e-05
437 3.238780001810682e-05
438 3.090920614506564e-05
439 2.949871485929832e-05
440 2.8153136567733162e-05
441 2.6868886898472002e-05
442 2.5643055129932115e-05
443 2.4473573658353893e-05
444 2.3357198768348176e-05
445 2.2291935980322558e-05
446 2.1275427524007097e-05
447 2.0305904737085888e-05
448 1.9380353781303963e-05
449 1.849712548955223e-05
450 1.7654567720891034e-05
451 1.684995493918036e-05
452 1.608210231565809e-05
453 1.534952932906891e-05
454 1.4650294649452702e-05
455 1.3982914925466771e-05
456 1.3346016320607996e-05
457 1.27385485017188