In [12]:
import torch as t
import einops
from typing import Callable
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from optimizers import *

Here we use some of the code from Chapter 0.3 of ARENA 3.0 (https://github.com/callummcdougall/ARENA_3.0/tree/main) to investigate how sensitive different optimisation algorithms (RMSprop and Adam) are to their hyper-parameters, using Rosenbrock's banana function as the test objective.

In [17]:
def plot_fn(fn: Callable, points, x_range=[-6, 6], y_range=[-10, 6], n_points=100, return_fig=False, log_scale=False):
    '''Plot a two-argument function over a rectangular domain with marker points overlaid.

    Produces a two-panel figure: a 3D surface (left) and a 2D heatmap (right) of
    `fn` evaluated on an `n_points` x `n_points` grid, with `points` drawn as red
    markers on both panels.

    Args:
        fn: callable taking (x, y) and returning values elementwise; it is called
            on 2D torch grids and on the columns `points[:, 0]`, `points[:, 1]`.
        points: array-like indexable as `points[:, 0]` (x) and `points[:, 1]` (y).
        x_range, y_range: [min, max] bounds of the plotted domain.
        n_points: grid resolution along each axis.
        return_fig: if True, return the figure instead of displaying it.
        log_scale: if True, colour the heatmap by log(fn) (hover text still shows
            the raw value); fn must then be strictly positive on the domain.

    Returns:
        The plotly Figure when `return_fig` is True; otherwise the figure is
        shown and None is returned.
    '''
    x = t.linspace(*x_range, n_points)
    xx = einops.repeat(x, "w -> h w", h=n_points)
    y = t.linspace(*y_range, n_points)
    yy = einops.repeat(y, "h -> h w", w=n_points)

    # z[i, j] = fn(x[j], y[i]): rows follow y, columns follow x (plotly's convention).
    z = fn(xx, yy)

    fig = make_subplots(
        specs=[[{"type": "scene"}, {"type": "xy"}]],
        rows=1, cols=2,
        subplot_titles=["3D plot", "2D plot" if not log_scale else "2D plot (log scale)"]
    ).update_layout(
        height=500, width=1100, title_font_size=40,
    ).update_annotations(
        font_size=20
    )

    fig.add_trace(
        go.Surface(
            x=x, y=y, z=z,
            showscale=False,
            colorscale="greys",
            hovertemplate = '<b>x</b> = %{x:.2f}<br><b>y</b> = %{y:.2f}<br><b>z</b> = %{z:.2f}',
            contours = dict(
                x = dict(show=True, color="grey", start=x_range[0], end=x_range[1], size=0.2),
                y = dict(show=True, color="grey", start=y_range[0], end=y_range[1], size=0.2),
            )
        ), row=1, col=1
    )
    fig.add_trace(
        go.Heatmap(
            x=x, y=y, z=z if not log_scale else t.log(z),
            showscale=False,
            # Hover shows the un-logged value even when colouring by log(z).
            customdata=z,
            hovertemplate = '<b>x</b> = %{x:.2f}<br><b>y</b> = %{y:.2f}<br><b>z</b> = %{customdata:.2f}',
            colorscale="greys",
        ),
        row=1, col=2
    )
    fig.add_trace(go.Scatter3d(x=points[:, 0], y=points[:, 1], z=fn(points[:, 0], points[:, 1]), mode="markers", marker=dict(size=3, color="red"), showlegend=False), row=1, col=1)
    fig.add_trace(go.Scatter(x=points[:, 0], y=points[:, 1], mode="markers", marker=dict(size=6, color="red"), showlegend=False), row=1, col=2)
    fig.update_scenes(aspectmode="cube")
    # showlegend is a layout property; keyword arguments to fig.show() only
    # configure the renderer and are silently ignored otherwise.
    fig.update_layout(showlegend=False)

    if return_fig:
        return fig
    fig.show()

def opt_fn(fn: Callable, xy: t.Tensor, optimizer_class, optimizer_hyperparams: dict, n_iters: int = 100):
    '''Run `n_iters` optimisation steps on `fn(xy[0], xy[1])` and record the trajectory.

    Args:
        fn: scalar objective taking two arguments (the two coordinates of `xy`).
        xy: leaf tensor holding the starting point. It must require grad and is
            updated in place by the optimiser.
        optimizer_class: optimiser constructor, called as
            `optimizer_class([xy], **optimizer_hyperparams)`.
        optimizer_hyperparams: keyword arguments forwarded to the optimiser.
        n_iters: number of optimisation steps.

    Returns:
        Tensor of shape (n_iters, *xy.shape) with the same dtype and device as `xy`.
        Row i is the position *before* step i, so row 0 is the starting point and
        the position after the final step is not included.
    '''
    assert xy.requires_grad

    # Match xy's dtype/device so float64 starting points are not silently
    # rounded to float32 when recorded.
    xys = t.zeros((n_iters, *xy.shape), dtype=xy.dtype, device=xy.device)
    optimizer = optimizer_class([xy], **optimizer_hyperparams)

    for i in range(n_iters):
        xys[i] = xy.detach()
        out = fn(xy[0], xy[1])
        out.backward()
        optimizer.step()
        optimizer.zero_grad()

    return xys

def rosenbrocks_banana_func(x, y, a=1, b=100):
    '''Rosenbrock's "banana" function, shifted up by 1.

    Computes (a - x)^2 + b * (y - x^2)^2 + 1 elementwise. Its minimum value of 1
    is at (x, y) = (a, a^2), so the output is always >= 1 for real inputs.
    '''
    offset = a - x          # distance from the minimiser along x
    valley = y - x**2       # distance from the curved valley floor
    return offset ** 2 + b * valley ** 2 + 1

In [14]:
# Hyper-parameter grids swept in the sensitivity experiment (50 values each).
momentums = np.linspace(0.9, 0.99, 50)  # RMSprop momentum / Adam beta_1
alphas = np.linspace(0.9, 0.99, 50)     # RMSprop alpha / Adam beta_2
lrs = np.linspace(0.01, 0.1, 50)        # learning rates

# Seed the generator so the random starting points (and every plot that
# depends on them) are reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
# 10 starting points drawn uniformly from the square [-5.5, 5.5)^2.
start_locs = rng.uniform(-5.5, 5.5, size=(10, 2))

In [15]:
# Show the objective surface with the random starting points overlaid in red.
plot_fn(fn=rosenbrocks_banana_func, points=start_locs)

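In particular, for each starting point we sweep one hyper-parameter at a time, keeping the others fixed, and run the optimiser for 100 iterations at each value. We then measure how quickly the optimiser's final recorded position changes as that hyper-parameter varies. For each pair of neighbouring hyper-parameter values, this is the distance between the two final positions divided by the difference between the values. Smaller values indicate that the result of the optimisation is less sensitive to that hyper-parameter.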

In [18]:
def final_position_sensitivity(fn, start_loc, optimizer_class, hyperparams_for, values, n_iters=100):
    '''Finite-difference sensitivity of the optimiser's end point to one hyper-parameter.

    For each value v in `values`, runs `opt_fn` from `start_loc` with
    `hyperparams_for(v)` and keeps the last recorded position end(v).

    Returns:
        (midpoints, quotients): for each consecutive pair (v_k, v_{k+1}),
        the midpoint (v_k + v_{k+1}) / 2 and
        ||end(v_{k+1}) - end(v_k)|| / (v_{k+1} - v_k).
        Both arrays have length len(values) - 1.
    '''
    values = np.asarray(values)
    end_points = t.stack([
        opt_fn(
            fn,
            xy=t.tensor(start_loc, requires_grad=True),
            optimizer_class=optimizer_class,
            optimizer_hyperparams=hyperparams_for(value),
            n_iters=n_iters,
        )[-1]
        for value in values
    ])
    jump_sizes = (end_points[1:] - end_points[:-1]).norm(dim=1).numpy()
    # Place each quotient over the interval it was measured on, so x and y
    # have equal length and are correctly aligned.
    midpoints = (values[1:] + values[:-1]) / 2
    return midpoints, jump_sizes / np.diff(values)


# One entry per panel: (row, col, optimiser, swept values,
# hyper-parameters as a function of the swept value, x-axis label).
# Order matches the panel titles so plotly's default trace colours are unchanged.
SWEEPS = [
    (1, 1, RMSprop, momentums, lambda v: {"lr": 0.02, "alpha": 0.99, "momentum": v}, "momentum"),
    (2, 1, RMSprop, alphas,    lambda v: {"lr": 0.02, "alpha": v, "momentum": 0.99}, "alpha"),
    (3, 1, RMSprop, lrs,       lambda v: {"lr": v, "alpha": 0.99, "momentum": 0.99}, "learning rate"),
    (1, 2, Adam,    momentums, lambda v: {"lr": 0.02, "betas": (v, 0.99)}, "beta_1 (momentum)"),
    (2, 2, Adam,    alphas,    lambda v: {"lr": 0.02, "betas": (0.99, v)}, "beta_2"),
    (3, 2, Adam,    lrs,       lambda v: {"lr": v, "betas": (0.99, 0.99)}, "learning rate"),
]

fig = make_subplots(rows=3, cols=2, subplot_titles=["Momentum (RMSprop)", "Momentum (Adam)", "Variance Regularisation (RMSprop)", "Variance Regularisation (Adam)", "Learning Rates (RMSprop)", "Learning Rates (Adam)"])

for row, col, optimizer_class, values, hyperparams_for, x_label in SWEEPS:
    # One line per starting location in each panel.
    for start_loc in start_locs:
        x_mid, sensitivity = final_position_sensitivity(
            rosenbrocks_banana_func, start_loc, optimizer_class, hyperparams_for, values
        )
        fig.add_scatter(x=x_mid, y=sensitivity, mode="lines", row=row, col=col, showlegend=False)
    fig.update_xaxes(title_text=x_label, row=row, col=col)

fig.update_yaxes(title_text="|Δ final position| / Δ value", col=1)
fig.update_layout(height=900)
fig.show()

We see that Adam, in addition to providing more desirable convergence characteristics, is also more stable under hyper-parameter tuning.