In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [5]:
# Seed NumPy's global RNG so the synthetic data set is identical on every run.
np.random.seed(666)
m = 1000  # number of samples
n = 10  # number of features
X = np.random.random(size = (m,n))  # uniform draws in [0, 1), shape (m, n)

# Ground-truth parameters 1, 2, ..., n+1 (n + 1 values: one more than the feature count).
true_theta = np.arange(1,n+2, dtype = 'float')

In [6]:
# Prepend a column of ones to X so that the first parameter acts as the intercept.
ones = np.ones((X.shape[0], 1))
X_b = np.hstack([ones, X])
# Targets: the linear model's output under true_theta plus standard-normal noise (one draw per sample).
y = X_b.dot(true_theta) + np.random.normal(size = m)

In [7]:
# Sanity check: show the shapes of the generated data and the ground-truth parameters.
print(X.shape)
print(y.shape)
print(true_theta)

(1000, 10)
(1000,)
[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]


In [13]:
def J(theta, X_b, y):
    """Mean-squared-error loss of a linear model.

    Args:
        theta: Parameter vector, one entry per column of ``X_b``.
        X_b: Design matrix that already has the leading column of ones
            prepended.
        y: Target values, one per row of ``X_b``.

    Returns:
        The mean of the squared residuals ``(y - X_b.dot(theta)) ** 2``,
        or ``float('inf')`` if the evaluation raises an ordinary exception
        (for example, incompatible shapes).
    """
    try:
        y_hat = X_b.dot(theta)
        return np.sum((y - y_hat) ** 2) / len(X_b)
    except Exception:
        # A failed evaluation is reported as an infinitely bad loss. Catching
        # Exception rather than using a bare `except:` lets KeyboardInterrupt
        # and SystemExit propagate, so a long run can still be interrupted.
        return float('inf')

In [14]:
def dJ_math(theta, X_b, y):
    """Analytic gradient of the mean-squared-error loss with respect to theta.

    Args:
        theta: Parameter vector, one entry per column of ``X_b``.
        X_b: Design matrix (bias column included).
        y: Target values, one per row of ``X_b``.

    Returns:
        The vector ``X_b.T.dot(X_b.dot(theta) - y) * 2 / len(y)``.
    """
    residual = X_b.dot(theta) - y
    return X_b.T.dot(residual) * 2. / len(y)

In [15]:
def dJ_debug(theta, X_b, y, epsilon = 0.01):
    """Approximate the gradient of J at theta with central finite differences.

    Each component k is estimated as
    ``(J(theta + epsilon * e_k) - J(theta - epsilon * e_k)) / (2 * epsilon)``,
    where ``e_k`` is the k-th unit vector. ``theta`` itself is not modified;
    the perturbations are applied to copies.

    Args:
        theta: Parameter vector at which to estimate the gradient.
        X_b: Design matrix, forwarded to ``J``.
        y: Target values, forwarded to ``J``.
        epsilon: Half-width of the finite-difference step.

    Returns:
        A NumPy array of length ``len(theta)`` holding the estimated gradient.
    """
    n_params = len(theta)
    grad = np.empty(n_params)
    for k in range(n_params):
        theta_plus = theta.copy()
        theta_minus = theta.copy()
        theta_plus[k] += epsilon
        theta_minus[k] -= epsilon
        loss_diff = J(theta_plus, X_b, y) - J(theta_minus, X_b, y)
        grad[k] = loss_diff / (2. * epsilon)
    return grad

### 将函数传入函数！
`gradient_descent()` 的第一个参数 `dJ` 是一个求梯度的函数：调用时传入 `dJ_debug` 或 `dJ_math`，即可切换梯度的计算方式。

In [16]:
def gradient_descent(dJ, X_b, y, initial_theta, eta, n_iters = 1e4, epsilon = 1e-8):
    """Minimise the loss J by batch gradient descent.

    Args:
        dJ: Callable ``dJ(theta, X_b, y)`` returning the gradient of the loss
            at ``theta`` (e.g. ``dJ_math`` or ``dJ_debug``).
        X_b: Design matrix, forwarded to ``dJ`` and ``J``.
        y: Target values, forwarded to ``dJ`` and ``J``.
        initial_theta: Starting parameter vector. It is never modified in
            place; every update builds a new array.
        eta: Learning rate (step size).
        n_iters: Maximum number of gradient updates to perform.
        epsilon: Stop early once the loss changes by less than this amount
            between two consecutive updates.

    Returns:
        The parameter vector after the last update performed.
    """
    theta = initial_theta
    # Carry the current loss across iterations so each step evaluates J only
    # once instead of recomputing the previous loss every time.
    loss = J(theta, X_b, y)
    i_iter = 0
    # Strict `<` so that at most n_iters updates are made (`<=` ran one extra).
    while i_iter < n_iters:
        gradient = dJ(theta, X_b, y)
        theta = theta - eta * gradient
        last_loss, loss = loss, J(theta, X_b, y)
        if abs(loss - last_loss) < epsilon:
            break
        i_iter += 1

    return theta

In [17]:
# Start the search from the all-zero vector, one entry per column of X_b.
initial_theta = np.zeros(X_b.shape[1])
eta = 0.01  # learning rate passed to gradient_descent

# Time gradient descent driven by the finite-difference gradient dJ_debug, then display the result.
%time theta = gradient_descent(dJ_debug, X_b, y, initial_theta, eta)
theta

CPU times: user 2.78 s, sys: 5.24 ms, total: 2.79 s
Wall time: 2.79 s


array([ 1.12512892,  2.05312927,  2.91523174,  4.11896693,  5.05002716,
        5.90494571,  6.97384418,  8.00088859,  8.86214178,  9.98608873,
       10.90529852])

In [18]:
# Same run with the analytic gradient dJ_math; compare the timing and theta with the dJ_debug run.
%time theta = gradient_descent(dJ_math, X_b, y, initial_theta, eta)
theta

CPU times: user 2.44 s, sys: 11 ms, total: 2.45 s
Wall time: 411 ms


array([ 1.12512892,  2.05312927,  2.91523174,  4.11896693,  5.05002716,
        5.90494571,  6.97384418,  8.00088859,  8.86214178,  9.98608873,
       10.90529852])

### dJ_debug:
1. 作为一种数值（numerical）方法，dJ_debug 可以得到足够精确的梯度，从而求出正确答案
2. 计算耗时更长：每求一次梯度，都要对 theta 的每个分量各计算两次损失函数 J
3. 应用：先用 dJ_debug 求出正确答案，再自己推导 dJ_math，并用 dJ_debug 的结果来检验 dJ_math 的推导是否正确。
4. dJ_debug 适用于任何损失函数（只需要能计算 J 本身）；而 dJ_math 的梯度解析表达式必须针对每种损失函数重新推导