In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Synthetic linear-regression data: y = intercept + X @ weights + N(0, 1) noise.
SEED = 666
N_SAMPLES = 1000
N_FEATURES = 10

np.random.seed(SEED)
X = np.random.random(size=(N_SAMPLES, N_FEATURES))

# One coefficient per feature plus the intercept: 1., 2., ..., N_FEATURES + 1.
true_theta = np.arange(1, N_FEATURES + 2, dtype=float)
# Prepend a column of ones so true_theta[0] acts as the intercept.
Xb = np.hstack([np.ones((N_SAMPLES, 1)), X])
y = Xb.dot(true_theta) + np.random.normal(size=N_SAMPLES)

In [3]:
# Ground-truth parameters: true_theta[0] is the intercept (it multiplies the
# ones column of Xb); true_theta[1:] weight the feature columns of X.
true_theta

array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

In [4]:
# Cost (loss) function; theta is the parameter vector being solved for,
# Xb and y are the fixed data of the objective.
def J(theta, Xb, y):
    """Mean squared error ``||y - Xb @ theta||^2 / len(Xb)``.

    Parameters
    ----------
    theta : parameter vector being optimised.
    Xb : design matrix; its length is used as the number of samples.
    y : target values.

    Returns
    -------
    float
        The cost, or ``float('inf')`` if evaluating it raises an ordinary
        exception (e.g. mismatched shapes), so a caller comparing costs
        treats that point as infinitely bad.
    """
    try:
        # residual @ residual is the sum of squared residuals, computed
        # without building the squared array first.
        residual = y - Xb @ theta
        return (residual @ residual) / len(Xb)
    except Exception:
        # Only Exception, not a bare except: KeyboardInterrupt / SystemExit
        # must still propagate out of long gradient-descent loops.
        return float('inf')

In [5]:
def dJ_math(theta, Xb, y):
    """Analytical gradient of the MSE cost: ``2/m * Xb^T (Xb @ theta - y)``.

    ``m`` is taken as ``len(y)``. Returns an array with one entry per
    column of ``Xb``.
    """
    residual = Xb @ theta - y
    n_samples = len(y)
    return Xb.T @ residual * 2 / n_samples

In [6]:
def dJ_debug(theta, Xb, y, epsilon=0.01, cost_fn=None):
    """Estimate the gradient of a cost function by central differences.

    Component i is approximated as
    ``(cost(theta + eps*e_i) - cost(theta - eps*e_i)) / (2*eps)``.
    This costs two cost evaluations per parameter, so it is slow. It does
    not depend on a hand-derived formula, which makes it a cross-check for
    an analytical gradient such as ``dJ_math``.

    Parameters
    ----------
    theta : parameter vector; converted to a float array before perturbing.
    Xb, y : passed through unchanged to the cost function.
    epsilon : half-width of the finite-difference step.
    cost_fn : cost function ``f(theta, Xb, y)``; defaults to ``J``.

    Returns
    -------
    numpy.ndarray
        Float array with one gradient estimate per element of ``theta``.
    """
    if cost_fn is None:
        cost_fn = J
    # Work in float: with an integer theta, `theta_plus[i] += epsilon` would
    # be truncated back to the original value and every estimate would be 0.
    base = np.asarray(theta, dtype=float)
    res = np.empty(len(base))
    for i in range(len(base)):
        theta_plus = base.copy()
        theta_plus[i] += epsilon
        theta_minus = base.copy()
        theta_minus[i] -= epsilon
        res[i] = (cost_fn(theta_plus, Xb, y) - cost_fn(theta_minus, Xb, y)) / (2 * epsilon)
    return res

In [7]:
 def gradient_descent(dJ, Xb, y, initial_theta, eta, n_iters = 1e4, epsilon=1e-8):
    
    theta = initial_theta
    i_iter = 0

    while i_iter < n_iters:
        gradient = dJ(theta, Xb, y)
        last_theta = theta

        theta = theta - eta * gradient

        if(abs(J(theta, Xb, y) - J(last_theta, Xb, y)) < epsilon):
            break

        i_iter += 1

    return theta

In [8]:
X_b = np.hstack([np.ones((len(X), 1)), X])
initial_theta = np.zeros(X_b.shape[1])
eta = 0.01

%time theta = gradient_descent(dJ_debug, X_b, y, initial_theta, eta)
theta

CPU times: user 1.91 s, sys: 0 ns, total: 1.91 s
Wall time: 2.06 s


array([ 1.1251597 ,  2.05312521,  2.91522497,  4.11895968,  5.05002117,
        5.90494046,  6.97383745,  8.00088367,  8.86213468,  9.98608331,
       10.90529198])

In [9]:
%time theta = gradient_descent(dJ_math, X_b, y, initial_theta, eta)
theta

CPU times: user 281 ms, sys: 0 ns, total: 281 ms
Wall time: 326 ms


array([ 1.1251597 ,  2.05312521,  2.91522497,  4.11895968,  5.05002117,
        5.90494046,  6.97383745,  8.00088367,  8.86213468,  9.98608331,
       10.90529198])