# Random Cones

This notebook explores gradient agreement filtering (GAF) on random cones and compares it with SGD, Adam and AdaGrad. The goal of the optimization is to find the minimum of a true cone. However, the optimizer can only sample cones whose center point is randomly perturbed away from the true cone's center. This simulates the noise and randomness of real-world optimization problems.

In [5]:
# Automatically re-import edited modules (e.g. the local plots / random_cones / core modules)
# before each cell runs, so changes to them are picked up without restarting the kernel.
# Re-running this cell only prints a harmless "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Load Packages
import jax
import jax.numpy as jnp
import optax
from plotly.subplots import make_subplots

# Load Custom Modules
from plots import plot_true_function, add_optimization_path
from random_cones import optimize_random_cone, initial_point, cone
from core import rand_int

In [7]:
# Define Constants
X_MIN, X_MAX, X_POINTS = -10, 10, 100
Y_MIN, Y_MAX, Y_POINTS = -10, 10, 100
SEED = 42  # first argument to initial_point (presumably its PRNG seed); fixed for reproducibility

# Grid on which the true cone is plotted.
x = jnp.linspace(X_MIN, X_MAX, X_POINTS)
y = jnp.linspace(Y_MIN, Y_MAX, Y_POINTS)

# Randomly sample the initial point that every optimizer below starts from.
# minval/maxval are scalars; this reuses the x-range, which equals the y-range here.
pi = initial_point(SEED, minval=X_MIN, maxval=X_MAX)

# Parameters of the true cone and of the random perturbation applied to sampled cones.
# NOTE(review): presumably (center_x, center_y, slope) and per-component noise scales —
# confirm against random_cones.cone / optimize_random_cone.
true_params = (0, 0, 1)
noise_params = (5, 5, 1)

In [11]:
# Optimize Cone using SGD from optax and plot its path over the true cone.
p = optimize_random_cone(pi, true_params, noise_params, steps=50, learning_rate=0.25, batch_size=3, method='sgd')

# Distance from the final iterate to the true cone's apex (not to the origin).
# NOTE(review): assumes true_params[:2] is the apex (x0, y0) — confirm against random_cones.cone.
# With true_params = (0, 0, 1) this equals the norm of p[-1].
error = jnp.linalg.norm(p[-1] - jnp.asarray(true_params[:2]))

fig = plot_true_function(x, y, lambda xy: cone(xy, true_params))
fig.update_layout(title=f'SGD<br><sup>Error: {error:.3f}</sup>', title_x=0.5)
add_optimization_path(fig, p)
fig.show()

In [12]:
# Optimize the cone with Adam and plot its path over the true cone.
p = optimize_random_cone(pi, true_params, noise_params, steps=50, learning_rate=0.2, batch_size=10, method='adam')

# Distance from the final iterate to the true cone's apex (not to the origin).
# NOTE(review): assumes true_params[:2] is the apex (x0, y0) — confirm against random_cones.cone.
error = jnp.linalg.norm(p[-1] - jnp.asarray(true_params[:2]))

fig = plot_true_function(x, y, lambda xy: cone(xy, true_params))
fig.update_layout(title=f'Adam<br><sup>Error: {error:.3f}</sup>', title_x=0.5)
add_optimization_path(fig, p)
fig.show()

In [13]:
# Optimize the cone with AdaGrad and plot its path over the true cone.
p = optimize_random_cone(pi, true_params, noise_params, steps=50, learning_rate=1.0, batch_size=10, method='adagrad')

# Distance from the final iterate to the true cone's apex (not to the origin).
# NOTE(review): assumes true_params[:2] is the apex (x0, y0) — confirm against random_cones.cone.
error = jnp.linalg.norm(p[-1] - jnp.asarray(true_params[:2]))

fig = plot_true_function(x, y, lambda xy: cone(xy, true_params))
fig.update_layout(title=f'AdaGrad<br><sup>Error: {error:.3f}</sup>', title_x=0.5)
add_optimization_path(fig, p)
fig.show()

In [7]:
# Disabled L-BFGS experiment.
# NOTE(review): `compute_errors` is not imported anywhere in this notebook, so uncommenting
# this cell as-is would raise NameError. TODO: import it (or reuse the error computation
# used in the cells above), confirm optimize_random_cone accepts method='lbfgs', or delete
# this cell.
# p = optimize_random_cone(pi, true_params, noise_params, steps=10, learning_rate=0.1, method='lbfgs')
# 
# fig = plot_true_function(x, y, lambda x: cone(x, true_params))
# error, _ = compute_errors(p[-1], true_params)
# fig.update_layout(title=f'L-BFGS<br><sup>Error: {error:.3f}</sup>', title_x=0.5)
# add_optimization_path(fig, p)
# fig.show()

In [14]:
# Side-by-side comparison of all optimizers, each started from the same initial point `pi`.
# Each entry: (subplot title, method for optimize_random_cone, optimizer settings, (row, col)).
# FIXME: the 'GAF' panel currently runs plain SGD with exactly the SGD settings, so it does
# not show gradient agreement filtering. Switch it to the GAF method once
# optimize_random_cone exposes one (check random_cones for the supported method name).
OPTIMIZER_RUNS = [
    ('SGD',     'sgd',     dict(learning_rate=0.25, batch_size=3),  (1, 1)),
    ('Adam',    'adam',    dict(learning_rate=0.2,  batch_size=10), (1, 2)),
    ('AdaGrad', 'adagrad', dict(learning_rate=1.0,  batch_size=10), (2, 1)),
    ('GAF',     'sgd',     dict(learning_rate=0.25, batch_size=3),  (2, 2)),
]

cfig = make_subplots(rows=2, cols=2, subplot_titles=[title for title, *_ in OPTIMIZER_RUNS])

paths = {}
for title, method, settings, (row, col) in OPTIMIZER_RUNS:
    path = optimize_random_cone(pi, true_params, noise_params, steps=50, method=method, **settings)
    paths[title] = path

    # Build the single-optimizer figure, then copy every one of its traces
    # (true cone + optimization path) into this optimizer's grid cell.
    panel = plot_true_function(x, y, lambda xy: cone(xy, true_params))
    add_optimization_path(panel, path)
    for trace in panel.data:
        cfig.add_trace(trace, row=row, col=col)

# Keep the per-optimizer path names available to later cells.
p_sgd, p_adam, p_adagrad, p_gaf = paths['SGD'], paths['Adam'], paths['AdaGrad'], paths['GAF']

cfig.update_layout(height=1000, width=1000, title_text='Optimization Algorithms', title_x=0.5)
cfig.show()
