In [1]:
import tensorflow as tf
import numpy  as np


# Synthetic training set sampled from the line y = TRUE_W * x + TRUE_B plus noise.
SEED = 42          # fixed seed so a fresh kernel regenerates the same samples
N_SAMPLES = 100    # number of (x, y) points to draw
TRUE_W = 1.477     # slope of the generating line
TRUE_B = 0.089     # intercept of the generating line
NOISE_STD = 0.01   # standard deviation of the Gaussian observation noise

rng = np.random.default_rng(SEED)

samples = []  # accumulates [x, y] pairs before conversion to an array
for i in range(N_SAMPLES):
    x = rng.uniform(-10., 10.)       # sample the input x uniformly from [-10, 10)
    eps = rng.normal(0., NOISE_STD)  # sample the Gaussian noise term
    y = TRUE_W * x + TRUE_B + eps    # noisy model output
    samples.append([x, y])           # store the sample point
data = np.array(samples)  # 2-D NumPy array of shape (N_SAMPLES, 2): columns are x, y

In [3]:
# Debug check: `eps` still holds the noise value assigned in the last
# iteration of the sampling loop in the first cell.
print(eps)

-0.004151078948789324


In [5]:
# Compute the error
def mse(b, w, points):
    """Return the mean squared error of the line ``y = w * x + b`` on ``points``.

    Args:
        b: Intercept of the line.
        w: Slope of the line.
        points: 2-D array-like whose column 0 holds the inputs x and whose
            column 1 holds the targets y. Nested lists are accepted as well
            as NumPy arrays.

    Returns:
        The average of ``(y - (w * x + b)) ** 2`` over all rows.

    Raises:
        ZeroDivisionError: If ``points`` has no rows (the mean is undefined).
    """
    pts = np.asarray(points)
    if len(pts) == 0:
        # Raised explicitly: NumPy would otherwise return nan with only a warning.
        raise ZeroDivisionError("cannot compute the mean squared error of zero points")
    x = pts[:, 0]
    y = pts[:, 1]
    residual = y - (w * x + b)
    # One vectorized reduction replaces the per-sample Python loop.
    return np.mean(residual ** 2)

In [6]:
# Compute the gradient
def step_gradient(b_current, w_current, points, lr):
    """Perform one gradient-descent update of ``b`` and ``w`` on the MSE loss.

    Args:
        b_current: Current intercept.
        w_current: Current slope.
        points: 2-D array-like whose column 0 holds the inputs x and whose
            column 1 holds the targets y. Nested lists are accepted as well
            as NumPy arrays.
        lr: Learning rate that scales the gradient step.

    Returns:
        ``[new_b, new_w]``, the parameters after one update. When ``points``
        has no rows both gradients are zero, so the parameters are unchanged.
    """
    pts = np.asarray(points)
    if len(pts) == 0:
        # No samples contribute, so both gradients stay at zero.
        b_gradient = 0
        w_gradient = 0
    else:
        x = pts[:, 0]
        y = pts[:, 1]
        residual = (w_current * x + b_current) - y
        # Derivative of the loss with respect to b: (2/M) * sum(residual)
        b_gradient = 2.0 * np.mean(residual)
        # Derivative of the loss with respect to w: (2/M) * sum(x * residual)
        w_gradient = 2.0 * np.mean(x * residual)
    # Gradient-descent update of b and w, where lr is the learning rate
    new_b = b_current - (lr * b_gradient)
    new_w = w_current - (lr * w_gradient)
    return [new_b, new_w]

In [7]:
# Gradient-descent training loop
def gradient_descent(points, starting_b, starting_w, lr, num_iterations):
    """Run ``num_iterations`` gradient-descent updates and return the result.

    Args:
        points: 2-D array-like of samples; column 0 is x, column 1 is y.
        starting_b: Initial intercept.
        starting_w: Initial slope.
        lr: Learning rate passed to every update step.
        num_iterations: Number of update steps to perform.

    Returns:
        ``[b, w]`` after the final update.
    """
    # Convert once, outside the loop, so the update step and the loss are
    # evaluated on the same array instead of re-copying the data every step.
    points = np.asarray(points)
    b = starting_b
    w = starting_w
    # Apply the gradient-descent update repeatedly
    for step in range(num_iterations):
        b, w = step_gradient(b, w, points, lr)
        if step % 50 == 0:
            # The loss is only used for this progress line, so it is computed
            # on the steps that print it rather than on every iteration.
            loss = mse(b, w, points)
            print(f"iteration:{step}, loss:{loss}, w:{w}, b:{b}")
    return [b, w]  # parameters after the last update

In [10]:
# Training entry point
def main():
    """Fit a line to the module-level ``data`` array and print the outcome.

    Starts from ``b = 0`` and ``w = 0``, runs 1000 gradient-descent steps with
    a learning rate of 0.01, then prints the mean squared error of the fitted
    parameters on the same data.
    """
    learning_rate = 0.01
    start_b, start_w = 0, 0
    total_steps = 1000
    # gradient_descent hands the parameters back in [b, w] order.
    b, w = gradient_descent(data, start_b, start_w, learning_rate, total_steps)
    # Mean squared error of the fitted parameters on the training samples
    loss = mse(b, w, data)
    print(f'Final_loss:{loss}, w:{w}, b:{b}')

In [11]:
# Run the training; progress is printed every 50 iterations, followed by the final loss.
main()

iteration:0, loss:8.324909730150774, w:0.9763694560043762, b0.01808096884041486
iteration:50, loss:0.0006655704135862034, w:1.477553902277715, b0.06697030494724514
iteration:100, loss:0.00019594194645215811, w:1.4772992911148601, b0.0817730657056694
iteration:150, loss:0.00013244416262746898, w:1.47720566876936, b0.08721614670235554
iteration:200, loss:0.00012385871938283327, w:1.4771712431633548, b0.0892176065124538
iteration:250, loss:0.00012269789405528438, w:1.4771585846218231, b0.08995355759859
iteration:300, loss:0.00012254094054657384, w:1.477153929985571, b0.09022417207630892
iteration:350, loss:0.00012251971909033057, w:1.4771522184425323, b0.09032367895513638
iteration:400, loss:0.00012251684976801065, w:1.4771515890959162, b0.0903602683461306
iteration:450, loss:0.00012251646181110325, w:1.4771513576806647, b0.09037372252687785
iteration:500, loss:0.00012251640935601258, w:1.4771512725876217, b0.0903786697253597
iteration:550, loss:0.00012251640226363988, w:1.477151241298301