# Update rules

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation

from IPython.display import HTML
from matplotlib import cm
from matplotlib.colors import LogNorm

In [2]:
def sgd(f, df, x0, y0, lr, steps):
    """Trace plain gradient descent on a function of two variables.

    Parameters
    ----------
    f : callable
        Objective ``f(x, y)``; called once, on the full coordinate arrays,
        after the descent has finished.
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    x0, y0 : float
        Starting point of the trajectory.
    lr : float
        Learning rate multiplying each gradient component.
    steps : int
        Number of updates to perform.

    Returns
    -------
    list
        ``[x, y, z]`` where ``x`` and ``y`` are arrays of length
        ``steps + 1`` holding every visited point (start included) and
        ``z`` is ``f(x, y)``.
    """
    xs = np.zeros(steps + 1)
    ys = np.zeros(steps + 1)
    xs[0], ys[0] = x0, y0

    for nxt in range(1, steps + 1):
        cur = nxt - 1
        grad_x, grad_y = df(xs[cur], ys[cur])
        # Step against the gradient, scaled by the learning rate.
        xs[nxt] = xs[cur] - lr * grad_x
        ys[nxt] = ys[cur] - lr * grad_y

    return [xs, ys, f(xs, ys)]

In [3]:
def nesterov(f, df, x0, y0, lr, steps, momentum):
    """Trace gradient descent with Nesterov momentum on a two-variable function.

    The gradient is evaluated at the look-ahead point (current position plus
    ``momentum`` times the current velocity) rather than at the current
    position itself.

    Parameters
    ----------
    f : callable
        Objective ``f(x, y)``; called once on the full coordinate arrays.
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    x0, y0 : float
        Starting point of the trajectory.
    lr : float
        Learning rate multiplying the look-ahead gradient.
    steps : int
        Number of updates to perform.
    momentum : float
        Factor applied to the previous velocity.

    Returns
    -------
    list
        ``[x, y, z]`` with ``x`` and ``y`` of length ``steps + 1`` and
        ``z = f(x, y)``.
    """
    xs = np.zeros(steps + 1)
    ys = np.zeros(steps + 1)
    xs[0], ys[0] = x0, y0
    vel_x, vel_y = 0, 0

    for k in range(steps):
        # Peek at where the accumulated velocity is about to carry us.
        look_x = xs[k] + momentum * vel_x
        look_y = ys[k] + momentum * vel_y
        grad_x, grad_y = df(look_x, look_y)

        # Decay the old velocity, then correct it with the look-ahead gradient.
        vel_x = momentum * vel_x - lr * grad_x
        vel_y = momentum * vel_y - lr * grad_y

        xs[k + 1] = xs[k] + vel_x
        ys[k + 1] = ys[k] + vel_y

    return [xs, ys, f(xs, ys)]

In [4]:
def adagrad(f, df, x0, y0, lr, steps):
    """Trace Adagrad on a two-variable function.

    Each coordinate keeps a running sum of its squared gradients; the step
    for that coordinate is divided by the square root of the sum (plus 1e-8
    to avoid dividing by zero).

    Parameters
    ----------
    f : callable
        Objective ``f(x, y)``; called once on the full coordinate arrays.
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    x0, y0 : float
        Starting point of the trajectory.
    lr : float
        Base learning rate.
    steps : int
        Number of updates to perform.

    Returns
    -------
    list
        ``[x, y, z]`` with ``x`` and ``y`` of length ``steps + 1`` and
        ``z = f(x, y)``.
    """
    eps = 1e-8
    xs = np.zeros(steps + 1)
    ys = np.zeros(steps + 1)
    xs[0], ys[0] = x0, y0
    sq_sum_x, sq_sum_y = 0, 0

    for k in range(steps):
        grad_x, grad_y = df(xs[k], ys[k])

        # The accumulators only ever grow, so effective step sizes shrink.
        sq_sum_x += grad_x ** 2
        sq_sum_y += grad_y ** 2

        xs[k + 1] = xs[k] - lr * grad_x / (eps + np.sqrt(sq_sum_x))
        ys[k + 1] = ys[k] - lr * grad_y / (eps + np.sqrt(sq_sum_y))

    return [xs, ys, f(xs, ys)]

In [5]:
def rmsprop(f, df, x0, y0, lr, steps, decay_rate):
    """Trace RMSProp on a two-variable function.

    Like Adagrad, but the squared-gradient cache is an exponential moving
    average controlled by ``decay_rate`` instead of an ever-growing sum.

    Parameters
    ----------
    f : callable
        Objective ``f(x, y)``; called once on the full coordinate arrays.
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    x0, y0 : float
        Starting point of the trajectory.
    lr : float
        Base learning rate.
    steps : int
        Number of updates to perform.
    decay_rate : float
        Weight kept from the previous cache value; ``1 - decay_rate`` is
        given to the new squared gradient.

    Returns
    -------
    list
        ``[x, y, z]`` with ``x`` and ``y`` of length ``steps + 1`` and
        ``z = f(x, y)``.
    """
    eps = 1e-8
    xs = np.zeros(steps + 1)
    ys = np.zeros(steps + 1)
    xs[0], ys[0] = x0, y0
    avg_sq_x, avg_sq_y = 0, 0

    for k in range(steps):
        grad_x, grad_y = df(xs[k], ys[k])

        # Leaky average of squared gradients, one per coordinate.
        avg_sq_x = decay_rate * avg_sq_x + (1 - decay_rate) * grad_x ** 2
        avg_sq_y = decay_rate * avg_sq_y + (1 - decay_rate) * grad_y ** 2

        xs[k + 1] = xs[k] - lr * grad_x / (eps + np.sqrt(avg_sq_x))
        ys[k + 1] = ys[k] - lr * grad_y / (eps + np.sqrt(avg_sq_y))

    return [xs, ys, f(xs, ys)]

In [6]:
def adam(f, df, x0, y0, lr, steps, beta1, beta2):
    """Trace Adam (with bias correction) on a two-variable function.

    Each coordinate keeps an exponential moving average of its gradient
    (weight ``beta1``) and of its squared gradient (weight ``beta2``). Both
    averages start at zero, so they are divided by ``1 - beta ** t`` at step
    ``t`` (1-based) before being used.

    Parameters
    ----------
    f : callable
        Objective ``f(x, y)``; called once on the full coordinate arrays.
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    x0, y0 : float
        Starting point of the trajectory.
    lr : float
        Base learning rate.
    steps : int
        Number of updates to perform.
    beta1 : float
        Decay factor of the gradient average.
    beta2 : float
        Decay factor of the squared-gradient average.

    Returns
    -------
    list
        ``[x, y, z]`` with ``x`` and ``y`` of length ``steps + 1`` and
        ``z = f(x, y)``.
    """
    eps = 1e-8
    xs = np.zeros(steps + 1)
    ys = np.zeros(steps + 1)
    xs[0], ys[0] = x0, y0
    mean_x, mean_y = 0, 0
    sq_mean_x, sq_mean_y = 0, 0

    for t in range(1, steps + 1):
        prev = t - 1
        grad_x, grad_y = df(xs[prev], ys[prev])

        # Correction denominators are shared by both coordinates at step t.
        mean_fix = 1 - beta1 ** t
        sq_fix = 1 - beta2 ** t

        mean_x = beta1 * mean_x + (1 - beta1) * grad_x
        sq_mean_x = beta2 * sq_mean_x + (1 - beta2) * grad_x ** 2
        mean_y = beta1 * mean_y + (1 - beta1) * grad_y
        sq_mean_y = beta2 * sq_mean_y + (1 - beta2) * grad_y ** 2

        step_x = lr * (mean_x / mean_fix) / (eps + np.sqrt(sq_mean_x / sq_fix))
        step_y = lr * (mean_y / mean_fix) / (eps + np.sqrt(sq_mean_y / sq_fix))

        xs[t] = xs[prev] - step_x
        ys[t] = ys[prev] - step_y

    return [xs, ys, f(xs, ys)]

In [7]:
def update_lines(num, dataLines, lines):
    """Animation callback: reveal the first ``num`` points of every trajectory.

    Parameters
    ----------
    num : int
        Number of leading points to show.
    dataLines : sequence of arrays
        One array per trajectory, indexed as ``[coordinate, step]`` with
        rows 0, 1 and 2 holding x, y and z.
    lines : sequence
        Line artists, paired positionally with ``dataLines``.

    Returns
    -------
    The ``lines`` object that was passed in.
    """
    for artist, trajectory in zip(lines, dataLines):
        shown = trajectory[:, :num]
        # x/y go through set_data; z has to be supplied separately.
        artist.set_data(shown[0:2])
        artist.set_3d_properties(shown[2])
        # Put a single marker on the newest point only.
        artist.set_marker('o')
        artist.set_markevery([-1])
    return lines

def create_and_save_animation(func_title, f, df, params=None, plot_params=None):
    """Animate five optimizers on the surface of ``f`` and save the result as a GIF.

    Runs ``sgd``, ``nesterov``, ``adagrad``, ``rmsprop`` and ``adam`` on the
    same objective, draws their trajectories over a 3D surface plot of ``f``
    sampled on ``np.arange(-6.5, 6.5, 0.1)`` in both directions, and writes
    the animation to ``optimization_<func_title>.gif`` in the current
    working directory.

    Parameters
    ----------
    func_title : str
        Plot title; also embedded in the output file name.
    f : callable
        Objective ``f(x, y)``; must accept numpy arrays (it is evaluated on
        the mesh grid and on whole trajectories).
    df : callable
        Gradient ``df(x, y)`` returning the pair ``(d/dx, d/dy)``.
    params : dict, optional
        Optimizer settings. Recognised keys and their defaults:
        ``'x0'`` (0), ``'y0'`` (0), ``'lr'`` (0.1), ``'steps'`` (8),
        ``'momentum'`` (0.9), ``'decay_rate'`` (0.9), ``'beta1'`` (0.9),
        ``'beta2'`` (0.999). The start point and learning rate can be
        overridden per optimizer with ``'x0_<name>'``, ``'y0_<name>'`` and
        ``'lr_<name>'``, where ``<name>`` is one of ``sgd``, ``nesterov``,
        ``adagrad``, ``rmsprop``, ``adam``.
    plot_params : dict, optional
        Camera/title settings: ``'azim'`` (-29), ``'elev'`` (49) and
        ``'rotation'`` (-7, rotation of the title text).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation object (already saved to disk).

    Notes
    -----
    Sets ``plt.rcParams['animation.html']`` to ``'html5'`` as a global side
    effect. Saving uses matplotlib's ``'imagemagick'`` writer, so that
    writer has to be usable in the running environment.
    """
    # None sentinels instead of mutable ``{}`` defaults shared between calls.
    params = {} if params is None else params
    plot_params = {} if plot_params is None else plot_params

    x0 = params.get('x0', 0)
    y0 = params.get('y0', 0)
    lr = params.get('lr', .1)
    steps = params.get('steps', 8)
    momentum = params.get('momentum', .9)
    decay_rate = params.get('decay_rate', .9)
    beta1 = params.get('beta1', .9)
    beta2 = params.get('beta2', .999)

    def start_and_lr(name):
        """Return (x0, y0, lr) for one optimizer, honouring its overrides."""
        return (params.get('x0_' + name, x0),
                params.get('y0_' + name, y0),
                params.get('lr_' + name, lr))

    azim = plot_params.get('azim', -29)
    elev = plot_params.get('elev', 49)
    rotation = plot_params.get('rotation', -7)

    # Attach a 3D axis to the figure. ``Axes3D(fig)`` no longer registers
    # itself with the figure on current matplotlib releases (deprecated in
    # 3.4), which would leave the figure empty; ``add_axes`` with the '3d'
    # projection works on old and new versions. The rectangle matches the
    # full-figure default that ``Axes3D(fig)`` used.
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_axes([0, 0, 1, 1], projection='3d', azim=azim, elev=elev)

    # Plot the surface.
    grid = np.arange(-6.5, 6.5, 0.1)
    grid_x, grid_y = np.meshgrid(grid, grid)
    ax.plot_surface(grid_x, grid_y, f(grid_x, grid_y), rstride=1, cstride=1,
                    norm=LogNorm(), cmap=cm.jet)
    ax.set_title(func_title, rotation=rotation)

    # One [x, y, z] trajectory per optimizer; stacked, the array is indexed
    # as [optimizer, coordinate, step].
    data = np.array([
        sgd(f, df, *start_and_lr('sgd'), steps),
        nesterov(f, df, *start_and_lr('nesterov'), steps, momentum),
        adagrad(f, df, *start_and_lr('adagrad'), steps),
        rmsprop(f, df, *start_and_lr('rmsprop'), steps, decay_rate),
        adam(f, df, *start_and_lr('adam'), steps, beta1, beta2),
    ])

    # NOTE: Can't pass empty arrays into 3d version of plot(), so each line
    # is seeded with its first point.
    lines = [ax.plot(dat[0, 0:1], dat[1, 0:1], dat[2, 0:1])[0] for dat in data]
    ax.legend(lines, ['SGD', 'Nesterov Momentum', 'Adagrad', 'RMSProp', 'Adam'])

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    plt.rcParams['animation.html'] = 'html5'

    # Frames run from 0 (nothing revealed) to steps + 1 (every point shown).
    line_ani = animation.FuncAnimation(fig, update_lines, steps + 2, fargs=(data, lines),
                                       interval=500, blit=False, repeat=False)

    # Close the figure so the notebook does not also render a static copy.
    plt.close(fig)
    # NOTE(review): the GIF is written at 5 fps while ``interval=500`` above
    # corresponds to 2 fps; kept as in the original -- confirm this is intended.
    line_ani.save(f'optimization_{func_title}.gif', writer='imagemagick', fps=500/100)

    return line_ani

In [8]:
func_title = 'sphere_function'

def f(x, y):
    """Sphere function: the sum of the squares of ``x`` and ``y``."""
    x_sq = x ** 2
    y_sq = y ** 2
    return x_sq + y_sq

def df(x, y):
    """Gradient of the sphere function, returned as the tuple ``(2x, 2y)``."""
    grad_x = 2 * x
    grad_y = 2 * y
    return (grad_x, grad_y)

# Sphere demo: all optimizers share the learning rate and step count, but
# each gets its own start point so the paths do not overlap -- two clusters,
# one near (-4, -4) and one near (-4, 4).
create_and_save_animation(
    func_title, f, df,
    params={
        'steps': 15,
        'lr': .2,
        'x0_sgd': -4, 'y0_sgd': -4,
        'x0_nesterov': -4.2, 'y0_nesterov': -3.8,
        'x0_adagrad': -4, 'y0_adagrad': 4,
        'x0_rmsprop': -4.2, 'y0_rmsprop': 3.8,
        'x0_adam': -4, 'y0_adam': 4.2,
    },
    plot_params={'azim': 15, 'elev': 60, 'rotation': -7},
)

In [9]:
func_title = 'himmelblau_function'

def f(x, y):
    """Himmelblau's function: ``(x^2 + y - 11)^2 + (x + y^2 - 7)^2``."""
    first = x ** 2 + y - 11
    second = x + y ** 2 - 7
    return first ** 2 + second ** 2

def df(x, y):
    """Gradient of Himmelblau's function as the tuple ``(d/dx, d/dy)``."""
    first = x ** 2 + y - 11
    second = x + y ** 2 - 7
    # Chain rule on each squared term.
    grad_x = 4 * x * first + 2 * second
    grad_y = 2 * first + 4 * y * second
    return (grad_x, grad_y)

# Himmelblau demo: every optimizer starts from (0, -3). SGD and Nesterov use
# the shared learning rate of .005, while the three adaptive methods are
# given .5.
create_and_save_animation(
    func_title, f, df,
    params={
        'steps': 25,
        'lr': .005,
        'x0': 0, 'y0': -3,
        'lr_adagrad': .5,
        'lr_rmsprop': .5,
        'lr_adam': .5,
    },
    plot_params={'azim': -29, 'elev': 70, 'rotation': 17},
)