In [5]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets

In [2]:
def gradient_descent(grad, lr, initial_position, n_steps=10, bound=5):
    """Run fixed-step gradient descent on a scalar function.

    Parameters
    ----------
    grad : callable
        Derivative of the objective; called as ``grad(x)`` on the current iterate.
    lr : float
        Learning rate (step size).
    initial_position : float
        Starting point x0.
    n_steps : int, optional
        Maximum number of update steps (default 10).
    bound : float, optional
        Divergence cut-off (default 5): iteration stops at the first iterate
        with ``|x| > bound`` (or NaN), and that iterate is not returned.

    Returns
    -------
    numpy.ndarray
        The iterates ``[x0, x1, ...]``, at most ``n_steps + 1`` long, every one
        satisfying ``|x| <= bound``. Empty if ``initial_position`` itself is
        out of bounds.
    """
    # `not (... <= bound)` also catches NaN, which a plain `> bound` lets through.
    if not np.abs(initial_position) <= bound:
        return np.array([])
    path = [initial_position]
    for _ in range(n_steps):
        x_next = path[-1] - lr * grad(path[-1])
        # Validate each new iterate before keeping it, so the final step is
        # bounds-checked too and a diverged point is never returned.
        if not np.abs(x_next) <= bound:
            break
        path.append(x_next)
    return np.array(path)

def f(x):
    """Non-convex quartic objective f(x) = x**4 - 3*x**3 + x**2 + x + 1.

    Evaluates element-wise, so it accepts scalars or NumPy arrays.
    """
    quartic_term = x**4
    cubic_term = 3*x**3
    return quartic_term - cubic_term + x**2 + x + 1

def grad_f(x):
    """Analytic derivative of f: f'(x) = 4*x**3 - 9*x**2 + 2*x + 1."""
    cubic_term = 4*x**3
    quadratic_term = 9*x**2
    return cubic_term - quadratic_term + 2*x + 1

def g(x):
    """Convex quadratic bowl g(x) = x**2; evaluates element-wise on arrays."""
    squared = x**2
    return squared

def grad_g(x):
    """Analytic derivative of g: d/dx x**2 = 2*x."""
    slope = x * 2
    return slope

In [6]:
# Interactive demo: gradient descent on the convex quadratic g(x) = x**2,
# started at x0 = -2. With update x <- x*(1 - 2*lr), iterates converge
# monotonically for lr < 0.5, oscillate while converging for 0.5 < lr < 1,
# and diverge for lr > 1.
fig, ax = plt.subplots()
slider = widgets.FloatSlider(value=0.0, min=0.0, max=1.5, step=0.001)
# `layout` is the widget styling trait; a misspelled keyword is ignored and
# the border would not be drawn.
output = widgets.Output(layout={'border': '1px solid black'})

@output.capture(clear_output=True, wait=True)
def on_value_change(lr):
    """Redraw g and the gradient-descent path for learning rate `lr`.

    Runs inside `output.capture(clear_output=True)`, so each call replaces the
    figure previously shown in the `output` widget.
    """
    xs = np.linspace(-4, 4, 200)
    x_gd = gradient_descent(grad_g, lr, -2)
    ax.clear()
    ax.plot(xs, g(xs))
    ax.plot(x_gd, g(x_gd), marker='o')
    ax.set_xlim(-4, 4)
    ax.set_ylim(-0.5, 16)
    ax.set_xlabel('x')
    ax.set_ylabel('g(x)')
    ax.set_title(f'Gradient descent on $g(x) = x^2$, lr = {lr:.3f}')
    display(ax.figure)

widgets.interactive(on_value_change, lr=slider)
display(slider, output)
# Close the figure so Jupyter does not render it again below the cell; it is
# shown only through `output`.
plt.close(fig)


FloatSlider(value=0.0, description='lr', max=1.5, step=0.001)

Output()

In [7]:
# Interactive demo: gradient descent on the non-convex quartic
# f(x) = x**4 - 3*x**3 + x**2 + x + 1 (two local minima), started at x0 = -1.3.
# The slider range is much narrower than for g because f's gradient is steep.
fig, ax = plt.subplots()
slider = widgets.FloatSlider(value=0.0, min=0.0, max=0.17, step=0.001)
# `layout` is the widget styling trait; a misspelled keyword is ignored and
# the border would not be drawn.
output = widgets.Output(layout={'border': '1px solid black'})

@output.capture(clear_output=True, wait=True)
def on_value_change(lr):
    """Redraw f and the gradient-descent path for learning rate `lr`.

    Runs inside `output.capture(clear_output=True)`, so each call replaces the
    figure previously shown in the `output` widget.
    """
    xs = np.linspace(-1.5, 3, 200)
    x_gd = gradient_descent(grad_f, lr, -1.3)
    ax.clear()
    ax.plot(xs, f(xs))
    ax.plot(x_gd, f(x_gd), marker='o')
    ax.set_xlim(-2, 3)
    ax.set_ylim(-2, 18)
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_title(f'Gradient descent on quartic $f(x)$, lr = {lr:.3f}')
    display(ax.figure)

widgets.interactive(on_value_change, lr=slider)
display(slider, output)
# Close the figure so Jupyter does not render it again below the cell; it is
# shown only through `output`.
plt.close(fig)

FloatSlider(value=0.0, description='lr', max=0.17, step=0.001)

Output()