# Setup

Dependencies:
- System: python3
- Python: jupyter, numpy, matplotlib, jax (for autodifferentiation)

Example setup for an Ubuntu system (Mac users can likely use `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):
```
/usr/bin/python3 -m pip install --upgrade pip
pip install --upgrade jupyter numpy matplotlib jax jaxlib
jupyter notebook  # from the directory of this notebook
```
Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/AA203-Examples/blob/master/Lecture-3/Constrained%20Optimization.ipynb).

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
plt.jet()
%matplotlib notebook

In [None]:
# Plotting code.


def backtracking_line_search(f, x0, p, k_max=10):
    """Return the best point among a fixed set of step sizes along direction `p`.

    The candidates are `x0 + 2**(-k) * p` for `k = -2, -1, ..., k_max - 1`, i.e.
    step sizes from 4 down to `2**-(k_max - 1)`. `f` is evaluated on all of them
    at once with `jax.vmap` and the candidate with the smallest value is returned.

    Args:
        f: scalar function of a single point.
        x0: point the search starts from.
        p: search direction.
        k_max: exclusive upper bound on the exponent `k`.

    Returns:
        The candidate point at which `f` is smallest.
    """
    # Normally this would be implemented with a `while` loop/appropriate exit conditions.
    exponents = np.arange(-2, k_max)
    step_sizes = 2.0**(-exponents)
    candidates = x0 + step_sizes[:, np.newaxis] * p
    best_index = jnp.argmin(jax.vmap(f)(candidates))
    return candidates[best_index]


def projected_gradient_descent(f, project, x0, num_steps):
    """Run projected gradient descent and return every intermediate iterate.

    Each iteration takes a line-search step from the current (projected) point
    along `-grad f`, scoring candidates by `f(project(.))`, and then projects
    the result back with `project`.

    Args:
        f: scalar objective of a single point, differentiable with `jax.grad`.
        project: function mapping a point to its projection.
        x0: starting point (1-D array).
        num_steps: number of (gradient step, projection) pairs to perform.

    Returns:
        Array with `2 + 2 * num_steps` rows: `x0`, `project(x0)`, then
        alternating gradient-step and projected points.
    """
    # Reference (plain Python loop) version of the same computation:
    #     xs = [x0, project(x0)]
    #     for _ in range(num_steps):
    #         x = xs[-1]
    #         xs.append(backtracking_line_search(lambda x: f(project(x)), x, -jax.grad(f)(x)))
    #         xs.append(project(xs[-1]))
    #     return jnp.array(xs)

    # Projected once and reused for both the output head and the scan's initial carry.
    start = project(x0)

    def scan_f(x, _):
        stepped = backtracking_line_search(lambda z: f(project(z)), x, -jax.grad(f)(x))
        projected = project(stepped)
        return projected, jnp.array([stepped, projected])

    # `pairs` has shape (num_steps, 2, dim); flatten the first two axes while keeping
    # the point dimension generic instead of assuming 2-D points.
    _, pairs = jax.lax.scan(scan_f, start, None, num_steps)
    return jnp.concatenate(
        [jnp.array([x0, start]),
         jnp.reshape(pairs, (-1, pairs.shape[-1]))], 0)


def unconstrained_optimization(f, x0, num_steps=20):
    """Minimize `f` from `x0` with a fixed number of line-search gradient steps.

    Args:
        f: scalar objective of a single point, differentiable with `jax.grad`.
        x0: starting point.
        num_steps: number of descent iterations to run.

    Returns:
        The point reached after `num_steps` iterations.
    """

    # This should actually be a full `jax.lax.while_loop` with an appropriate termination criterion.
    def descend(_, x):
        direction = -jax.grad(f)(x)
        return backtracking_line_search(f, x, direction)

    return jax.lax.fori_loop(0, num_steps, descend, x0)


def penalty_method(f, constraint, x0, num_steps):
    """Solve an equality-constrained problem with a quadratic penalty method.

    Subproblem `i` minimizes `f(y) + 2**i * constraint(y)**2` with
    `unconstrained_optimization`, warm-started from the previous subproblem's
    result.

    Args:
        f: scalar objective of a single point.
        constraint: scalar function that should equal zero at feasible points.
        x0: starting point (array).
        num_steps: number of penalty subproblems; the weight doubles each time.

    Returns:
        Array with `num_steps + 1` rows: `x0` followed by the result of each
        subproblem.
    """

    def scan_f(x, i):
        # Float base: `2**i` with the integer scan index is integer arithmetic and
        # wraps around once `i` reaches the integer width (31 for int32).
        weight = 2.0**i
        x_next = unconstrained_optimization(lambda y: f(y) + weight * constraint(y)**2, x)
        return x_next, x_next

    _, trace = jax.lax.scan(scan_f, x0, jnp.arange(num_steps))
    return jnp.concatenate([x0[np.newaxis], trace], 0)


class InteractiveContourPlot:
    """Contour plot of a 2-D scalar function that forwards mouse events to `redraw`.

    Clicking toggles `self.active`; every click and mouse move stores the event
    in `self.event` and calls `redraw`, which subclasses must implement.
    """

    def __init__(self, f, x_range, y_range, *args, **kwargs):
        """Draw the contours of `f` and hook up the mouse callbacks.

        Args:
            f: function mapping a length-2 point to a scalar (evaluated on the
                grid through a nested `jax.vmap`).
            x_range: `(low, high)` extent of the grid along x.
            y_range: `(low, high)` extent of the grid along y.
            *args: forwarded to `plt.figure`.
            **kwargs: forwarded to `plt.figure`.
        """
        self.f = f
        self.X, self.Y = np.meshgrid(np.linspace(*x_range, 100), np.linspace(*y_range, 100), indexing='ij')

        self.fig = plt.figure(*args, **kwargs)
        self.ax = self.fig.add_subplot(111)
        self.ax.contour(self.X, self.Y, jax.vmap(jax.vmap(f))(np.stack([self.X, self.Y], -1)), levels=50)

        self.active = False
        # Most recent mouse event; this is the attribute the handlers write and `redraw` reads.
        self.event = None
        # Keep both connection ids under distinct names so either callback can be disconnected.
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidmotion = self.fig.canvas.mpl_connect('motion_notify_event', self.on_move)

    def redraw(self):
        """Update the plot for the current `self.event`; must be overridden."""
        raise NotImplementedError

    def on_press(self, event):
        """Record the click, toggle the active state, and redraw."""
        self.event = event
        self.active = not self.active
        self.redraw()

    def on_move(self, event):
        """Record the mouse-move event and redraw."""
        self.event = event
        self.redraw()


class ProjectedGradientDescentPlot(InteractiveContourPlot):
    """Interactive plot tracing projected gradient descent from the mouse position.

    The projection maps points onto the circle of radius `self.radius` drawn in
    black. Arrows alternate in color: red for projection steps, black for
    gradient steps.
    """

    def __init__(self, f, x_range, y_range, num_steps=10, *args, **kwargs):
        """Set up the contour plot, the constraint circle, and the jitted solver.

        Args:
            f: function mapping a length-2 point to a scalar.
            x_range: `(low, high)` extent of the plot along x.
            y_range: `(low, high)` extent of the plot along y.
            num_steps: number of projected gradient descent iterations to trace.
            *args: forwarded to the base class (and on to `plt.figure`).
            **kwargs: forwarded to the base class (and on to `plt.figure`).
        """
        super().__init__(f, x_range, y_range, *args, **kwargs)
        self.radius = 0.75
        theta = np.linspace(0, 2 * np.pi, 100)
        self.ax.plot(self.radius * np.cos(theta), self.radius * np.sin(theta), linewidth=2, color='black')
        self.project = lambda x: self.radius * x / jnp.linalg.norm(x)
        self.projected_gradient_descent = jax.jit(lambda x0: projected_gradient_descent(f, self.project, x0, num_steps))
        self.descent_trace_plot = None

    def redraw(self):
        """Recompute and draw the descent trace starting at the last mouse position."""
        if not self.active:
            return
        event = getattr(self, 'event', None)
        # Mouse events outside the axes carry no data coordinates; nothing to start from.
        if event is None or event.xdata is None or event.ydata is None:
            return
        x0 = np.array([event.xdata, event.ydata])
        xs = np.asarray(self.projected_gradient_descent(x0))
        diffs = np.diff(xs, axis=0)
        if self.descent_trace_plot is not None:
            self.descent_trace_plot.set_offsets(xs[:-1])
            self.descent_trace_plot.set_UVC(diffs[:, 0], diffs[:, 1])
        else:
            self.descent_trace_plot = self.ax.quiver(xs[:-1, 0],
                                                     xs[:-1, 1],
                                                     diffs[:, 0],
                                                     diffs[:, 1],
                                                     angles='xy',
                                                     scale_units='xy',
                                                     scale=1,
                                                     width=0.005,
                                                     color=['red', 'black'])
        # Explicitly request a repaint so the update shows even when auto-redraw is off.
        self.fig.canvas.draw_idle()


class PenaltyMethodPlot(InteractiveContourPlot):
    """Interactive plot tracing the penalty method from the mouse position.

    The constraint is zero on the circle of radius `self.radius` drawn in black.
    Each red arrow connects the results of consecutive penalty subproblems.
    """

    def __init__(self, f, x_range, y_range, num_steps=10, *args, **kwargs):
        """Set up the contour plot, the constraint circle, and the jitted solver.

        Args:
            f: function mapping a length-2 point to a scalar.
            x_range: `(low, high)` extent of the plot along x.
            y_range: `(low, high)` extent of the plot along y.
            num_steps: number of penalty subproblems to trace.
            *args: forwarded to the base class (and on to `plt.figure`).
            **kwargs: forwarded to the base class (and on to `plt.figure`).
        """
        super().__init__(f, x_range, y_range, *args, **kwargs)
        self.radius = 0.75
        theta = np.linspace(0, 2 * np.pi, 100)
        self.ax.plot(self.radius * np.cos(theta), self.radius * np.sin(theta), linewidth=2, color='black')
        self.constraint = lambda x: (x[0]**2 + x[1]**2 - self.radius**2) / 10
        self.penalty_method = jax.jit(lambda x0: penalty_method(f, self.constraint, x0, num_steps))
        self.descent_trace_plot = None

    def redraw(self):
        """Recompute and draw the penalty-method trace starting at the last mouse position."""
        if not self.active:
            return
        event = getattr(self, 'event', None)
        # Mouse events outside the axes carry no data coordinates; nothing to start from.
        if event is None or event.xdata is None or event.ydata is None:
            return
        x0 = np.array([event.xdata, event.ydata])
        xs = np.asarray(self.penalty_method(x0))
        diffs = np.diff(xs, axis=0)
        if self.descent_trace_plot is not None:
            self.descent_trace_plot.set_offsets(xs[:-1])
            self.descent_trace_plot.set_UVC(diffs[:, 0], diffs[:, 1])
        else:
            self.descent_trace_plot = self.ax.quiver(xs[:-1, 0],
                                                     xs[:-1, 1],
                                                     diffs[:, 0],
                                                     diffs[:, 1],
                                                     angles='xy',
                                                     scale_units='xy',
                                                     scale=1,
                                                     width=0.005,
                                                     color='red')
        # Explicitly request a repaint so the update shows even when auto-redraw is off.
        self.fig.canvas.draw_idle()

In [None]:
# Defining some interesting functions.


def gaussian_mixture_model_logpdf(x, weights, means, covariances):
    """Log-density of a Gaussian mixture at `x`.

    Args:
        x: point at which to evaluate the density.
        weights: mixture weight of each component.
        means: mean of each component.
        covariances: covariance matrix of each component.

    Returns:
        `log(sum_k weights[k] * N(x; means[k], covariances[k]))`, computed with
        a weighted log-sum-exp over the per-component log-densities.
    """
    component_logpdf = jax.scipy.stats.multivariate_normal.logpdf
    per_component = component_logpdf(x, means, covariances)
    return jax.scipy.special.logsumexp(per_component, b=weights)


def two_local_minima(x):
    """Negative log-density (scaled by 1/10) of a fixed three-component Gaussian mixture.

    Makes for an interesting-looking optimization landscape; details unimportant.

    Args:
        x: length-2 point at which to evaluate the landscape.

    Returns:
        Scalar objective value at `x`.
    """
    # A local generator seeded with 0 yields the same draws as
    # `np.random.seed(0); np.random.rand(...)` without resetting the global NumPy RNG
    # every time the objective is evaluated.
    rng = np.random.RandomState(0)
    scales = np.array([.3, .4, .6])
    cov_factor = np.eye(2)[np.newaxis] * scales[:, np.newaxis, np.newaxis] + (rng.rand(3, 2, 2) - 0.5) / 2

    weights = np.array([0.1, 0.2, 0.7])
    means = np.array([[-.7, .7], [0., .6], [.8, -.8]])
    # factor @ factor^T gives a symmetric positive semi-definite covariance per component.
    covariances = np.matmul(cov_factor, cov_factor.swapaxes(-1, -2))
    return -gaussian_mixture_model_logpdf(x, weights, means, covariances) / 10


def rosenbrock(x):
    """Rosenbrock-style valley `((1 - x0)**2 + 4 * (x1 - x0**2)**2) / 10`.

    Args:
        x: indexable point; only `x[0]` and `x[1]` are read.

    Returns:
        The function value, which is 0 at `(1, 1)`.
    """
    first, second = x[0], x[1]
    offset_term = (1 - first)**2
    valley_term = 4 * (second - first**2)**2
    return (offset_term + valley_term) / 10

In [None]:
# Click to activate/deactivate projected gradient descent visualization (you may have to wait a bit for JAX compilation).
# Red: projection step
# Black: gradient step
# Keep the plot bound to a name: matplotlib holds only weak references to bound-method
# callbacks, so the object must stay referenced for clicks and mouse moves to keep working.
pgd_plot = ProjectedGradientDescentPlot(
    two_local_minima,  # try also `rosenbrock`
    (-1.5, 1.5),
    (-1.5, 1.5),
    figsize=(8, 8))

In [None]:
# Click to activate/deactivate penalty method visualization (you may have to wait a bit for JAX compilation).
# Bound to a name so the plot object, whose methods serve as the canvas callbacks, stays referenced.
p = PenaltyMethodPlot(
    two_local_minima,  # try also `rosenbrock`
    (-1.5, 1.5),
    (-1.5, 1.5),
    figsize=(8, 8))