In [None]:
%matplotlib inline


Warm-up: numpy
--------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x using Euclidean error.

This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.



In [1]:
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 30368427.89502176
1 24596280.07574503
2 21634727.07359833
3 18710124.589152154
4 14965810.281355958
5 10995620.786856908
6 7490364.51516824
7 4928529.698373835
8 3229250.8289861595
9 2176453.3110217587
10 1531782.9290568922
11 1132846.7227820947
12 876159.9691349426
13 703450.4816085335
14 580711.8402883159
15 489300.15241233294
16 418352.8106328645
17 361620.95582767506
18 315381.7150116408
19 276817.46257662144
20 244209.75320620328
21 216299.44594541026
22 192309.2621413725
23 171566.81343089943
24 153536.5452315146
25 137796.3032426593
26 124003.10066256963
27 111854.3172900539
28 101134.46017496486
29 91631.92805661797
30 83169.54714288596
31 75622.38384551098
32 68877.78089086813
33 62833.931537752964
34 57404.997281342614
35 52519.394245994015
36 48110.95648380505
37 44129.68796179197
38 40526.54274342889
39 37260.47381499283
40 34294.52957533554
41 31595.70293193597
42 29141.217901233045
43 26902.80182048758
44 24860.332890880825
45 22997.26662883165
46 21294.03015721282
47 1