In [1]:
import numpy as np

In [2]:
def f(x: np.ndarray) -> float:
    """Toy example function that is convex and has a global minimum at (0,0)"""
    return 0.5*(x[0]**2 + 10*x[1]**2)

In [3]:
def gradient(x: np.ndarray) -> np.ndarray:
    return np.array([x[0], 10*x[1]])

In [6]:
def gradient_descent(x0, num_iters, alpha, epsilon=1e-6):
    x = x0
    for i in range(num_iters):
        x_new = x - alpha * gradient(x)

        if np.abs(f(x_new) - f(x)) < epsilon:
            break

        x = x_new

    return x

In [8]:
x0 = np.array([8, 7])
gradient_descent(x0=x0, num_iters=1000, alpha=0.01)

array([1.00113903e-02, 2.60835571e-30])

In [9]:
def momentum(x0, num_iters, alpha, beta, epsilon=1e-6):
    x = x0
    inertia = 0
    for i in range(num_iters):
        inertia = beta * inertia + alpha * gradient(x)
        x_new = x - inertia

        if np.abs(f(x_new) - f(x)) < epsilon:
            break

        x = x_new

    return x

In [11]:
momentum(x0=x0, num_iters=1000, alpha=0.01, beta=0.5)

array([ 6.89326119e-03, -3.34610417e-51])

In [12]:
def nesterov(x0, num_iters, alpha, beta, epsilon=1e-6):
    x = x0
    inertia = 0
    for i in range(num_iters):
        inertia = beta * inertia + alpha * gradient(x - inertia)
        x_new = x - inertia

        if np.abs(f(x_new) - f(x)) < epsilon:
            break

        x = x_new

    return x

In [13]:
nesterov(x0=x0, num_iters=1000, alpha=0.01, beta=0.5)

array([7.07856536e-03, 1.40748001e-33])

In [21]:
def adam(x0, num_iters, alpha, beta1, beta2, epsilon=1e-6, delta=1e-7):
    x = x0
    m = 0
    v = 0
    for i in range(1, num_iters + 1):
        g = gradient(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2

        m_hat = m / (1 - beta1 ** i)
        v_hat = v / (1 - beta2 ** i)

        x_new = x - alpha * m_hat / (np.sqrt(v_hat) + delta)

        if np.abs(f(x_new) - f(x)) < epsilon:
            break

        x = x_new

    return x

In [25]:
adam(x0=x0, num_iters=100000, alpha=0.01, beta1=0.9, beta2=0.999)

array([0.01127514, 0.00198679])