In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import time

In [None]:
# Gradient of the least-squares loss for a model that is linear in `weight`
def gradient(weight, X, Y):
    """Return the gradient of ``0.5 * sum((weight @ X - Y) ** 2)`` w.r.t. ``weight``.

    ``X`` is indexed row-wise: row ``X[i]`` holds, for every sample, the
    feature that multiplies ``weight[i]``. For a 2-D ``X`` this means the shape
    is ``(len(weight), n_samples)``. For a polynomial fit the rows would be
    powers of the sample points (e.g. ``np.vander(x, len(weight)).T``) —
    assumption, confirm against the caller.

    :param weight: 1-D sequence of model coefficients (lists are accepted).
    :param X: feature matrix with one row per coefficient.
    :param Y: target values, one per sample (column of ``X``).
    :return: float array of length ``len(weight)``; entry ``i`` is
        ``sum(residual * X[i])`` with ``residual = weight @ X - Y``.
        There is no factor of 2, so this is half the gradient of the plain
        sum of squared errors (the learning rate absorbs the constant).
    """
    weight = np.asarray(weight)
    # The residual is the same for every coefficient, so compute it once
    # outside the loop instead of once per coefficient.
    residual = weight.dot(X) - Y
    grad = np.zeros(len(weight))
    for i in range(len(weight)):
        grad[i] = np.sum(residual * X[i])
    return grad

In [None]:
# Plot the recorded loss values against the iteration index
def plot_loss(loss_history):
    """Show a line chart of ``loss_history`` with one point per iteration.

    Draws on the current matplotlib axes, labels the x axis 'Iteration' and
    the y axis 'Loss', then calls ``plt.show()``.

    :param loss_history: sequence of loss values in iteration order.
    """
    axes = plt.gca()
    axes.plot(loss_history)
    axes.set_xlabel('Iteration')
    axes.set_ylabel('Loss')
    plt.show()

In [None]:
# Gradient descent for a least-squares polynomial fit
def _polynomial_sse(weight, X, Y):
    """Sum of squared residuals between ``Y`` and ``np.polyval(weight, X)``."""
    return np.sum((Y - np.polyval(weight, X)) ** 2)


def gradient_descent(gradient, X, Y, initial, learning_rate=0.001, max_iter=100,
                     stop_tolerance=1e-3, online_loss_plot=True, return_history=False):
    """
    Fit polynomial coefficients with plain full-batch gradient descent.

    The loss that is tracked (and used for the stopping test) is
    ``sum((Y - np.polyval(weight, X)) ** 2)``, so ``weight`` holds polynomial
    coefficients ordered highest power first, as ``np.polyval`` expects.

    NOTE(review): the update direction comes entirely from ``gradient``; this
    function does not check that it is the gradient of the polyval-based loss
    above. Confirm the supplied gradient uses the same model and the same
    layout of ``X`` (a raw vector of sample points here).

    :param gradient: callable ``gradient(weight, X, Y)`` returning an array of
        the same length as ``weight``
    :param X: sample points at which the polynomial is evaluated
    :param Y: target values at those sample points
    :param initial: initial coefficients (any array-like; converted with ``np.asarray``)
    :param learning_rate: step size multiplying the gradient
    :param max_iter: maximum number of update steps
    :param stop_tolerance: stop as soon as the absolute change in loss between
        two consecutive steps is below this value
    :param online_loss_plot: if true, redraw the loss curve with ``plot_loss``
        after every step, including the step that triggers the stop
    :param return_history: if true, also return the loss and weight histories
    :return: the final coefficients; or ``(weight, loss_history, weight_history)``
        when ``return_history`` is true. Both histories include the initial state.
    """
    weight = np.asarray(initial)
    loss_history = [_polynomial_sse(weight, X, Y)]
    weight_history = [weight]
    for _ in range(max_iter):
        weight = weight - learning_rate * gradient(weight, X, Y)
        loss_history.append(_polynomial_sse(weight, X, Y))
        weight_history.append(weight)
        # Plot before the stop test so the final loss value is always drawn.
        if online_loss_plot:
            time.sleep(0.01)  # brief pause before each redraw
            plot_loss(loss_history)
        if np.abs(loss_history[-1] - loss_history[-2]) < stop_tolerance:
            break
    if return_history:
        return weight, loss_history, weight_history
    return weight