# In [None]:
''''''''''''''''''''
'''    梯度下降    '''      
''''''''''''''''''''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Custom loss: a paraboloid whose unique minimum (loss 0) is at (2, 3).
def loss_function(x, y):
    """Return the squared distance from (x, y) to the point (2, 3).

    Uses only arithmetic operators, so it works on Python scalars and
    element-wise on NumPy arrays (e.g. meshgrid outputs).
    """
    dx = x - 2
    dy = y - 3
    return dx ** 2 + dy ** 2

# Analytic gradient of the custom loss, derived by hand.
def gradient(x, y):
    """Return the partial derivatives (dL/dx, dL/dy) at (x, y).

    For L = (x - 2)^2 + (y - 3)^2 these are 2(x - 2) and 2(y - 3).
    """
    return 2 * (x - 2), 2 * (y - 3)

# Plain (fixed step size) gradient descent
def gradient_descent(x, y, learning_rate=0.1, num_iterations=50, loss_fn=None, grad_fn=None):
    """Run gradient descent from (x, y) and record every visited point.

    Args:
        x, y: Starting coordinates.
        learning_rate: Step size multiplied by the gradient at each update.
        num_iterations: Number of update steps to perform.
        loss_fn: Callable ``loss_fn(x, y)`` returning the loss value.
            Defaults to the module-level ``loss_function``.
        grad_fn: Callable ``grad_fn(x, y)`` returning ``(d_loss/dx, d_loss/dy)``.
            Defaults to the module-level ``gradient``.

    Returns:
        A NumPy array with one ``(x, y, loss)`` row for the starting point
        plus one row per update step, in visiting order.
    """
    # Resolve the defaults at call time (not at definition time) so that
    # redefining the module-level helpers, e.g. in a later notebook cell,
    # is picked up automatically.
    if loss_fn is None:
        loss_fn = loss_function
    if grad_fn is None:
        grad_fn = gradient

    trajectory = [(x, y, loss_fn(x, y))]
    for _ in range(num_iterations):
        grad_x, grad_y = grad_fn(x, y)
        # Step against the gradient, i.e. in the direction of steepest descent.
        x -= learning_rate * grad_x
        y -= learning_rate * grad_y
        trajectory.append((x, y, loss_fn(x, y)))
    return np.array(trajectory)

# Hyper-parameters and starting point of the descent
x = 0           # initial x coordinate
y = 0           # initial y coordinate
lr = 0.1        # learning rate (step size)
num_iters = 50  # number of descent steps
trajectory = gradient_descent(
    x,
    y,
    learning_rate=lr,
    num_iterations=num_iters,
)

# Sample the loss on a 400 x 400 grid covering x in [-1, 5] and y in [0, 6]
x_vals, y_vals = np.linspace(-1, 5, 400), np.linspace(0, 6, 400)
X, Y = np.meshgrid(x_vals, y_vals)
Z = loss_function(X, Y)

# Build the 3D figure with a semi-transparent loss surface
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7)

# Axis labels and title
ax.set(
    xlabel="X axis",
    ylabel="Y axis",
    zlabel="Loss",
    title="Gradient Descent on Loss Function",
)

# Initial camera angle
ax.view_init(elev=30, azim=120)

# Empty artists that the animation callback fills in frame by frame
trajectory_line = ax.plot([], [], [], color="red", linewidth=2, label="Trajectory")[0]
point = ax.plot([], [], [], "ro")[0]  # marker for the current iterate

# Animation callback: draw the path up to `frame` and rotate the camera
def update(frame):
    """Render animation frame ``frame``.

    Draws the trajectory prefix ``trajectory[:frame + 1]`` and moves the marker
    to the iterate visited at step ``frame``. It also turns the camera by 2
    degrees of azimuth per frame, starting from 120. Returns the two
    modified artists.
    """
    visited = trajectory[: frame + 1]
    trajectory_line.set_data(visited[:, 0], visited[:, 1])
    trajectory_line.set_3d_properties(visited[:, 2])

    cur_x, cur_y, cur_loss = trajectory[frame]
    point.set_data([cur_x], [cur_y])
    point.set_3d_properties([cur_loss])

    # Rotate the view so the surface is seen from a changing angle
    ax.view_init(elev=30, azim=120 - 2 * frame)

    return trajectory_line, point

# Disabled alternative to the animation: the statements below are wrapped in
# a string literal, so they are a no-op expression. Removing the surrounding
# quotes draws the complete trajectory and the final point once, at the
# initial camera angle (elev=30, azim=120), instead of frame by frame.
'''
trajectory_line.set_data(trajectory[:, 0], trajectory[:, 1])
trajectory_line.set_3d_properties(trajectory[:, 2])
point.set_data([trajectory[-1, 0]], [trajectory[-1, 1]])
point.set_3d_properties([trajectory[-1, 2]])
# 更新视角
ax.view_init(elev=30, azim=120)
'''

# Create the animation. The returned object must stay referenced (here via
# `ani`): it holds the only reference to the timer that advances the frames.
# An unreferenced FuncAnimation is garbage-collected, and matplotlib then
# warns "Animation was deleted without rendering anything".
ani = animation.FuncAnimation(fig, update, frames=len(trajectory), interval=200, blit=False)
ax.legend()

if "inline" in plt.get_backend():
    # Jupyter's inline backend only captures a static image of the figure, so
    # the timer-driven animation never plays there. Instead, close the static
    # figure and let the trailing `ani` expression be rendered through
    # Animation._repr_html_, which embeds all frames as a JavaScript player.
    # NOTE(review): very large animations may hit the
    # rcParams["animation.embed_limit"] size cap; raise it if frames go missing.
    plt.rcParams["animation.html"] = "jshtml"
    plt.close(fig)
else:
    # Interactive backends (plain scripts, %matplotlib widget / notebook)
    # play the animation live.
    plt.show()

ani  # last expression of the cell: displayed by Jupyter, a no-op in scripts

