In [16]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow
from matplotlib.animation import FuncAnimation, PillowWriter

# Function: f(x) = x^2
def f(x):
    """Return the square of ``x`` (the objective f(x) = x^2).

    Works on anything that supports the power operator, so both plain
    numbers and NumPy arrays (element-wise) are accepted.
    """
    return pow(x, 2)

# Derivative of the function: f'(x) = 2x
def grad_f(x):
    """Return the derivative of f(x) = x^2 evaluated at ``x``, i.e. 2x.

    Accepts plain numbers as well as NumPy arrays (element-wise).
    """
    return x * 2
    
def gradient_descent(learning_rate=0.1, x_start=3.0, num_steps=25,
                     output_path="gradient_descent_demo.gif"):
    """Animate gradient descent on f(x) = x^2 and save it as a transparent GIF.

    Starting from ``x_start``, the update ``x <- x - learning_rate * grad_f(x)``
    is applied ``num_steps`` times. The animation shows one frame per iterate:
    all points visited so far are drawn in red on the curve, and a green arrow
    marks the most recent step.

    Args:
        learning_rate: Step size that multiplies the gradient in each update.
        x_start: Initial value of x.
        num_steps: Number of update steps; the animation has one frame for the
            start point plus one per step.
        output_path: File the GIF is written to.

    Returns:
        None. The GIF is written to ``output_path`` and a confirmation line is
        printed once saving succeeds.
    """
    # Compute the whole descent path up front so frames only index into it.
    x_values = [x_start]
    for _ in range(num_steps):
        x_last = x_values[-1]
        x_values.append(x_last - learning_rate * grad_f(x_last))
    y_values = [f(x) for x in x_values]

    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    try:
        # Transparent figure and axes backgrounds.
        fig.patch.set_alpha(0)
        ax.set_facecolor("none")

        # Plot the function
        x_range = np.linspace(-3.5, 3.5, 500)
        ax.plot(x_range, f(x_range), label=r"$f(x) = x^2$", color="blue", zorder=1)

        # Set limits and labels
        ax.set_xlim(-3.5, 3.5)
        ax.set_ylim(-0.5, 10)
        ax.set_xlabel("w", fontsize=14)
        ax.set_ylabel("E=x\u00b2", fontsize=14)

        # Artists mutated by the per-frame callback.
        points, = ax.plot([], [], "o", color="red", zorder=2)
        arrow = None

        def update(frame):
            """Draw iterates 0..frame and an arrow for the latest step."""
            nonlocal arrow
            if arrow is not None:
                arrow.remove()
                # Forget the removed artist so it is never removed twice
                # (e.g. if frame 0 is drawn again before the next arrow).
                arrow = None
            points.set_data(x_values[:frame + 1], y_values[:frame + 1])

            if frame == 0:
                # No previous point yet, hence no step to draw.
                return (points,)

            x_prev, x_curr = x_values[frame - 1], x_values[frame]
            y_prev, y_curr = y_values[frame - 1], y_values[frame]
            arrow = FancyArrow(
                x_prev,
                y_prev,
                x_curr - x_prev,
                y_curr - y_prev,
                width=0.05,
                color="green",
                length_includes_head=True,
                zorder=3,
            )
            ax.add_patch(arrow)
            return points, arrow

        # Create animation
        ani = FuncAnimation(
            fig, update, frames=len(x_values), interval=200, blit=False
        )

        # Save the animation as a transparent GIF
        ani.save(
            output_path,
            writer=PillowWriter(fps=5),
            savefig_kwargs={"transparent": True},
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    print(f"GIF saved as '{output_path}'")

# Render the demo GIF using the default hyperparameters, spelled out explicitly
gradient_descent(learning_rate=0.1, x_start=3.0, num_steps=25)


GIF saved as 'gradient_descent_demo.gif'
