# Setup

Dependencies:
- System: python3
- Python: jupyter, numpy, matplotlib, jax and jaxlib (for automatic differentiation)

Example setup for an Ubuntu system (Mac users, maybe `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):
```
/usr/bin/python3 -m pip install --upgrade pip
pip install --upgrade jupyter numpy matplotlib jax jaxlib
jupyter notebook  # from the directory of this notebook
```
Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/AA203-Examples/blob/master/Lecture-3/Unconstrained%20Optimization.ipynb).

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Make "jet" the default colormap (used by the contour plots created below).
plt.jet()
# Interactive notebook backend: the plots below react to mouse click/motion events.
%matplotlib notebook

In [None]:
# Plotting code.


def backtracking_line_search(f, x0, p, k_max=10):
    """Pick the best of `k_max` geometrically shrinking steps from `x0` along `p`.

    The candidates are `x0 + 2**(-k) * p` for `k = 0, ..., k_max - 1`. All of them are
    evaluated in one batched call and the one with the smallest value of `f` is returned.

    Args:
        f: Objective mapping a single point to a scalar.
        x0: Starting point.
        p: Search direction; `k = 0` corresponds to taking the full step `p`.
        k_max: Number of candidate step sizes to try.

    Returns:
        The candidate point at which `f` is smallest.
    """
    # A textbook implementation would use a `while` loop with appropriate exit conditions.
    halvings = -np.arange(k_max)
    candidates = x0 + 2.0**halvings[:, np.newaxis] * p
    objective_values = jax.vmap(f)(candidates)
    return candidates[jnp.argmin(objective_values)]


def gradient_descent(f, descent_direction, x0, num_steps, constant_stepsize=None):
    """Run `num_steps` iterations of a descent method and return every iterate.

    Args:
        f: Objective function; only used by the line search.
        descent_direction: Callable mapping the current point to a step direction.
        x0: Initial point.
        num_steps: Number of iterations to perform.
        constant_stepsize: If given, each update is `x + direction * constant_stepsize`;
            if `None`, the next point is chosen by `backtracking_line_search`.

    Returns:
        Array stacking `x0` and the `num_steps` subsequent iterates along the first axis.
    """
    trajectory = [x0]
    # A plain Python loop keeps things simple; `jax.lax.scan` would be the better tool.
    for _ in range(num_steps):
        current = trajectory[-1]
        direction = descent_direction(current)
        if constant_stepsize is None:
            following = backtracking_line_search(f, current, direction)
        else:
            following = current + direction * constant_stepsize
        trajectory.append(following)
    return jnp.array(trajectory)


class InteractiveContourPlot:
    """Contour plot of a 2D scalar function that reacts to mouse clicks and motion.

    A click toggles `self.active`; every click or mouse move inside the axes then calls
    `redraw()`, which subclasses must implement (reading the position from `self.event`).
    """

    def __init__(self, f, x_range, y_range, *args, **kwargs):
        """Draw the contours of `f` and hook up the mouse callbacks.

        Args:
            f: Function mapping a length-2 point to a scalar.
            x_range: `(start, stop)` extent of the plot along x.
            y_range: `(start, stop)` extent of the plot along y.
            *args, **kwargs: Forwarded to `plt.figure`.
        """
        self.f = f
        # 100 x 100 evaluation grid; 'ij' indexing makes the first axis follow x.
        self.X, self.Y = np.meshgrid(np.linspace(*x_range, 100), np.linspace(*y_range, 100), indexing='ij')

        self.fig = plt.figure(*args, **kwargs)
        self.ax = self.fig.add_subplot(111)
        # Two nested `vmap`s evaluate `f` at every grid point in one call.
        self.ax.contour(self.X, self.Y, jax.vmap(jax.vmap(f))(np.stack([self.X, self.Y], -1)), levels=50)

        self.active = False
        # Most recent mouse event; this is the attribute the handlers and subclasses use.
        self.event = None
        # Keep both connection ids so either callback can be disconnected later.
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidmotion = self.fig.canvas.mpl_connect('motion_notify_event', self.on_move)

    def redraw(self):
        """Update the plot for `self.event`; must be provided by subclasses."""
        raise NotImplementedError

    def _redraw_if_in_axes(self, event):
        """Call `redraw()` only for events located inside this plot's axes."""
        # Matplotlib reports `xdata`/`ydata` as None for events outside any axes, so there is
        # no data-space position for `redraw()` to work with in that case.
        if event.inaxes is self.ax:
            self.redraw()

    def on_press(self, event):
        """Mouse-click callback: toggle the visualization and refresh it."""
        self.event = event
        self.active = not self.active
        self._redraw_if_in_axes(event)

    def on_move(self, event):
        """Mouse-motion callback: refresh the visualization at the new position."""
        self.event = event
        self._redraw_if_in_axes(event)


class DescentDirectionsPlot(InteractiveContourPlot):
    """Interactive contour plot showing two candidate step directions at the cursor.

    While `self.active` is set, two arrows are drawn from the point under the mouse: the
    negative gradient (black) and the Newton step (red).
    """

    def __init__(self, f, x_range, y_range, *args, **kwargs):
        """Build the contour plot and jit-compile both direction functions for `f`."""
        super().__init__(f, x_range, y_range, *args, **kwargs)

        def negative_gradient(x):
            return -jax.grad(f)(x)

        def newton_step(x):
            # Solves H(x) d = grad f(x) and negates; the Hessian is used as-is.
            return -jnp.linalg.solve(jax.hessian(f)(x), jax.grad(f)(x))

        self.neg_grad_f = jax.jit(negative_gradient)
        self.newton_f = jax.jit(newton_step)
        # Quiver artist, created lazily on the first redraw and updated in place afterwards.
        self.descent_directions_plot = None

    def redraw(self):
        """Move both arrows to the current mouse position (no-op while inactive)."""
        if not self.active:
            return
        point = np.array([self.event.xdata, self.event.ydata])
        # Order matters: it has to line up with the ['black', 'red'] colors below.
        directions = (self.neg_grad_f(point), self.newton_f(point))
        u = [direction[0] for direction in directions]
        v = [direction[1] for direction in directions]

        if not self.descent_directions_plot:
            self.descent_directions_plot = self.ax.quiver([point[0]] * 2, [point[1]] * 2,
                                                          u,
                                                          v,
                                                          angles='xy',
                                                          scale_units='xy',
                                                          scale=1,
                                                          width=0.005,
                                                          color=['black', 'red'])
            return

        self.descent_directions_plot.set_offsets([point, point])
        self.descent_directions_plot.set_UVC(u, v)


class GradientDescentPlot(InteractiveContourPlot):
    """Interactive contour plot that traces an iterative descent from the point under the cursor.

    While `self.active` is set, the descent is re-run from the current mouse position and each
    step is drawn as an arrow (black for 'grad', red for 'newton').
    """

    def __init__(self,
                 f,
                 x_range,
                 y_range,
                 descent_direction='grad',
                 num_steps=10,
                 constant_stepsize=None,
                 *args,
                 **kwargs):
        """Set up the contour plot and the jit-compiled descent routine.

        Args:
            f: Function mapping a length-2 point to a scalar.
            x_range: `(start, stop)` extent of the plot along x.
            y_range: `(start, stop)` extent of the plot along y.
            descent_direction: 'grad' for the normalized negative gradient, or 'newton' for a
                Newton step computed with a Hessian shifted to be positive definite.
            num_steps: Number of descent iterations to run and draw.
            constant_stepsize: Fixed multiplier applied to each direction; `None` selects the
                step by line search instead.
            *args, **kwargs: Forwarded to the base class constructor.

        Raises:
            ValueError: If `descent_direction` is neither 'grad' nor 'newton'.
        """
        # Validate before any figure is created; otherwise the missing direction function
        # would only surface as an unbound-variable error on the first click.
        if descent_direction not in ('grad', 'newton'):
            raise ValueError(f"Unknown descent_direction {descent_direction!r}; expected 'grad' or 'newton'.")

        super().__init__(f, x_range, y_range, *args, **kwargs)
        self.color = 'black'
        if descent_direction == 'grad':

            def descent_direction_fn(x):
                grad_f_x = jax.grad(f)(x)
                # Normalize the gradient; in general this is optional (and maybe irrelevant depending
                # on line search details).
                return -grad_f_x / jnp.linalg.norm(grad_f_x)

        else:  # 'newton'
            self.color = 'red'

            def descent_direction_fn(x):
                hess_f_x = jax.hessian(f)(x)
                # Ensure the hessian is positive definite by adding an appropriate multiple of the
                # identity matrix (if necessary). This ensures the returned direction is indeed a descent
                # direction (i.e., `np.dot(direction, grad_f_x) < 0`).
                # `eigvalsh` returns eigenvalues in ascending order, so index 0 is the smallest.
                min_hess_f_x_eigenvalue = jnp.linalg.eigvalsh(hess_f_x)[0]
                identity = jnp.eye(hess_f_x.shape[-1])
                pos_def_hess_f_x = hess_f_x + jnp.where(min_hess_f_x_eigenvalue < 1e-3,
                                                        (-min_hess_f_x_eigenvalue + 1e-3) * identity, 0.)
                return -jnp.linalg.solve(pos_def_hess_f_x, jax.grad(f)(x))

        # Compile the whole `num_steps`-long descent as a single function of the starting point.
        self.gradient_descent = jax.jit(
            lambda x0: gradient_descent(f, descent_direction_fn, x0, num_steps, constant_stepsize))
        # Quiver artist, created lazily on the first redraw and updated in place afterwards.
        self.descent_trace_plot = None

    def redraw(self):
        """Re-run the descent from the current mouse position and draw its steps (no-op while inactive)."""
        if not self.active:
            return
        x0 = np.array([self.event.xdata, self.event.ydata])
        xs = self.gradient_descent(x0)
        # One arrow per step: anchored at iterate k and pointing to iterate k + 1.
        diffs = np.diff(xs, axis=0)
        if self.descent_trace_plot:
            self.descent_trace_plot.set_offsets(xs[:-1])
            self.descent_trace_plot.set_UVC(diffs[:, 0], diffs[:, 1])
        else:
            self.descent_trace_plot = self.ax.quiver(xs[:-1, 0],
                                                     xs[:-1, 1],
                                                     diffs[:, 0],
                                                     diffs[:, 1],
                                                     angles='xy',
                                                     scale_units='xy',
                                                     scale=1,
                                                     width=0.005,
                                                     color=self.color)

In [None]:
# Defining some interesting functions.


def gaussian_mixture_model_logpdf(x, weights, means, covariances):
    """Return the log-density at `x` of a weighted mixture of Gaussians.

    Args:
        x: Point at which to evaluate the density.
        weights: Per-component mixture weights.
        means: Component means (stacked along the first axis by the caller in this file).
        covariances: Component covariance matrices, stacked like `means`.

    Returns:
        `log(sum_k weights[k] * N(x; means[k], covariances[k]))`.
    """
    # Log-density of `x` under each individual Gaussian component.
    component_logpdfs = jax.scipy.stats.multivariate_normal.logpdf(x, means, covariances)
    # Weighted log-sum-exp combines the components without leaving log space.
    mixture_logpdf = jax.scipy.special.logsumexp(component_logpdfs, b=weights)
    return mixture_logpdf


def two_local_minima(x):
    """Example objective: a negated, rescaled three-component Gaussian mixture log-density.

    Makes for an interesting-looking optimization landscape; details unimportant.

    Args:
        x: Length-2 point at which to evaluate the objective.

    Returns:
        `-gaussian_mixture_model_logpdf(x, ...) / 10` for a fixed set of mixture parameters.
    """
    # A private generator yields the same numbers as `np.random.seed(0); np.random.rand(3, 2, 2)`
    # but leaves NumPy's global random state untouched when this function is called.
    rng = np.random.RandomState(0)
    # Per-component scaled identity plus a fixed perturbation; squared below so that the
    # resulting covariance matrices are symmetric.
    scales = np.array([.3, .4, .6])[:, np.newaxis, np.newaxis]
    cov_factor = np.eye(2)[np.newaxis] * scales + (rng.rand(3, 2, 2) - 0.5) / 2

    weights = np.array([0.1, 0.2, 0.7])
    means = np.array([[-.7, .7], [0., .6], [.8, -.8]])
    covariances = np.matmul(cov_factor, cov_factor.swapaxes(-1, -2))
    return -gaussian_mixture_model_logpdf(x, weights, means, covariances) / 10


def rosenbrock(x):
    """Rosenbrock-style banana function with valley coefficient 4, scaled by 1/10.

    Args:
        x: Indexable with at least two entries; `x[0]` and `x[1]` are the coordinates.

    Returns:
        `((1 - x[0])**2 + 4 * (x[1] - x[0]**2)**2) / 10`.
    """
    first, second = x[0], x[1]
    offset_term = (1 - first)**2
    valley_term = 4 * (second - first**2)**2
    return (offset_term + valley_term) / 10

In [None]:
# Click to activate/deactivate descent direction visualization (you may have to wait a bit for JAX compilation).
# While active, the arrows follow the mouse cursor.
# Black: negative gradient (the steepest-descent direction)
# Red: direction from Newton's method (only guaranteed to be a descent direction if the Hessian is positive
# definite; visualized in all cases anyway)
# Both arrows are drawn unnormalized, in data units.
DescentDirectionsPlot(
    two_local_minima,  # try also `rosenbrock`
    (-1.5, 1.5),
    (-1.5, 1.5),
    figsize=(8, 8))

In [None]:
# Click to activate/deactivate gradient descent visualization (you may have to wait a bit for JAX compilation).
# While active, the descent is re-run from the mouse cursor and every step is drawn as an arrow
# (black for 'grad', red for 'newton').
GradientDescentPlot(
    two_local_minima,  # try also `rosenbrock`
    (-1.5, 1.5),
    (-1.5, 1.5),
    'grad',  # try 'grad' for (normalized) steepest descent, 'newton' for Newton's method
    num_steps=15,
    constant_stepsize=None,  # `None` for backtracking line search, otherwise try, e.g., `0.1`
    figsize=(8, 8))