### 梯度下降法调试参数

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# 1.准备数据
np.random.seed(666)
X = np.random.random(size = (1000,10))
true_theta = np.arange(1,12,dtype=float)
X_b = np.hstack([np.ones((len(X),1)),X])
y = X_b.dot(true_theta)

In [9]:
# 2.求损失函数
def J(theta, X_b, y):
    try:
        return np.sum((y - X_b.dot(theta))**2) / len(X_b)
    except:
        return float('inf')

In [10]:
# 3.用求损失函数偏导数的方法求 theta
def dJ_math(theta, X_b, y):
    return X_b.T.dot(X_b.dot(theta) - y) * 2. / len(y)

In [34]:
# 4.用导数定义式求出偏导
def dJ_debug(theta, X_b, y, epsilon=0.01):
    res = np.empty(len(theta))
    for i in range(len(theta)):
        #制造与 theta 无限接近的两个点
        theta_1 = theta.copy()
        theta_1[i] += epsilon
        theta_2 = theta.copy()
        theta_2[i] -= epsilon
        res[i] = (J(theta_1, X_b, y) - J(theta_2, X_b, y)) / (2 * epsilon)
    return res

In [35]:
# 5.求出 theta ,这里使用的批量梯度下降法
#参数 dJ是一个方法，这里分别用两种方法求出theta
def gradient_descent(dJ, X_b, y, initial_theta, eta, n_iters = 1e4, epsilon=1e-8):
    theta = initial_theta
    cur_iter = 0
    while cur_iter<n_iters:
        last_theta = theta
        gradient = dJ(theta,X_b,y)
        theta = theta - gradient*eta
        if(abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon):
            break
    cur_iter+=1
    return theta
        

In [36]:
X_b = np.hstack([np.ones((len(X),1)),X])
initial_theta = np.zeros(X_b.shape[1])
eta = 0.01

In [39]:
# 6.使用数学公式法
%time theta = gradient_descent(dJ_math, X_b, y, initial_theta, eta)
theta

CPU times: user 1.98 s, sys: 47.9 ms, total: 2.03 s
Wall time: 1.02 s


array([ 1.01729299,  1.99772027,  2.99619305,  3.99592825,  4.9966377 ,
        5.99705082,  6.99621895,  7.99723771,  8.99600862,  9.99695692,
       10.99632255])

In [40]:
# 7.使用求导定时式法
%time theta = gradient_descent(dJ_debug, X_b, y, initial_theta, eta)
theta

CPU times: user 15.6 s, sys: 368 ms, total: 15.9 s
Wall time: 7.97 s


array([ 1.01729299,  1.99772027,  2.99619305,  3.99592825,  4.9966377 ,
        5.99705082,  6.99621895,  7.99723771,  8.99600862,  9.99695692,
       10.99632255])