## Optimization Dynamics: an Interactive Tool

In this notebook we present a toy constrained optimization problem in 2D.

We provide an interactive widget which shows the optimization path realized when
solving the problem using [Cooper](https://github.com/gallego-posada/cooper).

> Acknowledgement: The visualizations and optimization problems presented here closely follow the blog posts by Degrave and Korshunova (2021a, 2021b), listed in the [References](#references).


## Table of Contents:
* [Setup](#setup)
* [Constrained Minimization Problem](#cmp)
* [Widget](#widget)
* [References](#references)

## Setup <a class="anchor" id="setup"></a>
Install Cooper together with its `examples` requirements.

In [2]:
!pip install -e git+https://github.com/gallego-posada/cooper#egg=.[examples] 

In [3]:
import copy

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import style_utils
import torch
import tutorial_utils
from IPython.display import display
from ipywidgets import HBox, Layout, VBox, interactive
from matplotlib.gridspec import GridSpec

import cooper

%matplotlib inline

torch.manual_seed(0)
np.random.seed(0)

## Constrained Minimization Problem <a class="anchor" id="cmp"></a>

Consider the following constrained optimization problem on
 the 2D domain $(x, y) \in [0,\pi/2] \times [0,\infty)$:

$$\begin{align*}
\underset{x, y}{\text{min}}\quad f(x,y) &:= \left(1 - \text{sin}(x) \right) \ \big(1+(y - 1)^2\big) & \tag{1} \\
s.t. \quad  g(x,y) &:= \left(1 - \text{cos}(x) \right)\ \big(1+(y-1)^2\big) - \epsilon \leq 0 & \\
\end{align*}$$

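given some $\epsilon \geq 0$.
Note that both $f$ and $g$ are convex in $x$ for any fixed $y$, and convex in $y$ for any fixed $x$, on the specified domain.
They are not jointly convex in $(x, y)$ (for example, the Hessian of $f$ is indefinite at $x = 0$, $y = 2$), so "convex" here refers to the curvature with respect to $x$.
We call this setting the convex problem; it corresponds to `problem_type="Convex"` in the `Toy2DCMP` class.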
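As a sanity check on Eq. (1), the constrained minimizer can be worked out by hand. The factor $1 + (y - 1)^2$ is smallest at $y = 1$, where the problem reduces to minimizing $1 - \sin(x)$ subject to $1 - \cos(x) \leq \epsilon$. For $\epsilon \in [0, 1]$ this gives $x^* = \arccos(1 - \epsilon)$ and $y^* = 1$. The sketch below compares this closed form against a brute-force grid search. It uses only the standard library, and the helper names (`f`, `g`, `closed_form_solution`, `grid_search`) are illustrative and not part of Cooper.

```python
import math


def f(x, y):
    """Objective of Eq. (1)."""
    return (1 - math.sin(x)) * (1 + (y - 1) ** 2)


def g(x, y, epsilon):
    """Constraint defect of Eq. (1); feasible points satisfy g <= 0."""
    return (1 - math.cos(x)) * (1 + (y - 1) ** 2) - epsilon


def closed_form_solution(epsilon):
    """Minimizer derived by hand, assuming 0 <= epsilon <= 1."""
    return math.acos(1 - epsilon), 1.0


def grid_search(epsilon, n=400, y_max=2.0):
    """Brute-force the best feasible point on an (n + 1) x (n + 1) grid."""
    best = None
    for i in range(n + 1):
        x = (math.pi / 2) * i / n
        for j in range(n + 1):
            y = y_max * j / n
            if g(x, y, epsilon) <= 0:
                value = f(x, y)
                if best is None or value < best[0]:
                    best = (value, x, y)
    return best


epsilon = 0.7
x_star, y_star = closed_form_solution(epsilon)
f_grid, x_grid, y_grid = grid_search(epsilon)
print(f"closed form: x*={x_star:.4f}, y*={y_star:.1f}, f*={f(x_star, y_star):.4f}")
print(f"grid search: x={x_grid:.4f}, y={y_grid:.4f}, f={f_grid:.4f}")
```

The grid result can only approach the closed form from above, since the grid rarely contains the exact boundary point where the constraint is active.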

The following class implements this CMP:

In [4]:
class Toy2DCMP(cooper.ConstrainedMinimizationProblem):
    def __init__(self, is_constrained=False, epsilon=1.0, problem_type="Convex"):
        self.problem_type = problem_type
        self.epsilon = epsilon
        super().__init__(is_constrained)

    def closure(self, params):
        """This function evaluates the objective function and constraint
        defect. It updates the attributes of this CMP based on the results."""

        x, y = params[:, 0], params[:, 1]

        if self.problem_type == "Convex":
            f = (1 - torch.sin(x)) * (1 + (y - 1.) ** 2)
            # In standard form (defect <= 0)
            g = (1 - torch.cos(x)) * (1 + (y - 1.) ** 2) - self.epsilon
        elif self.problem_type == "Concave":
            f = torch.sin(x) * (1 + (y - 1.) ** 2)
            # in standard form (defect <= 0)
            g = torch.cos(x) * (1 + (y - 1.) ** 2) - self.epsilon
        else:
            raise ValueError("Unknown problem type.")

        # Store the values in a CMPState as attributes
        state = cooper.CMPState(loss=f, ineq_defect=g)

        return state

### Bonus: concave $f(x, y)$ and $g(x, y)$.

Associated with `Toy2DCMP(problem_type="Concave")`

Consider a similar optimization problem to Eq. (1), also on $[0,\pi/2] \times [0,\infty)$:

$$\begin{align*}
\underset{x, y}{\text{min}}\quad f(x,y) &:= \text{sin}(x) \ \big(1+(y - 1)^2\big) & \tag{2} \\
s.t. \quad  g(x,y) &:= \text{cos}(x)\ \big(1+(y-1)^2\big) - \epsilon \leq 0 & \\
\end{align*}$$

given some $\epsilon \geq 0$.
In this case, $f$ and $g$ are concave with respect to $x$.
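The sign of the curvature in $x$ is what separates the two problems, and it can be checked numerically. The sketch below estimates $\partial^2 / \partial x^2$ with central finite differences on an interior grid of the domain, for the objective and constraint of both Eq. (1) and Eq. (2). The constant $\epsilon$ is omitted because it does not affect curvature. All names here (`FUNCTIONS`, `second_derivative_x`, `curvature_range`) are illustrative assumptions and not part of the notebook's code.

```python
import math


def scale(y):
    """Shared factor 1 + (y - 1)^2 that multiplies every term."""
    return 1 + (y - 1) ** 2


# Objective and constraint (without the constant epsilon) per problem type.
FUNCTIONS = {
    "Convex": {
        "f": lambda x, y: (1 - math.sin(x)) * scale(y),
        "g": lambda x, y: (1 - math.cos(x)) * scale(y),
    },
    "Concave": {
        "f": lambda x, y: math.sin(x) * scale(y),
        "g": lambda x, y: math.cos(x) * scale(y),
    },
}


def second_derivative_x(fn, x, y, h=1e-4):
    """Central finite-difference estimate of d^2 fn / dx^2."""
    return (fn(x + h, y) - 2 * fn(x, y) + fn(x - h, y)) / h**2


def curvature_range(fn, n=20):
    """Min and max curvature in x over an interior grid of the domain."""
    values = []
    for i in range(1, n):  # skip the endpoints x = 0 and x = pi / 2
        x = (math.pi / 2) * i / n
        for j in range(n + 1):
            y = 2.0 * j / n
            values.append(second_derivative_x(fn, x, y))
    return min(values), max(values)


for problem_type, fns in FUNCTIONS.items():
    for name, fn in fns.items():
        low, high = curvature_range(fn)
        print(f"{problem_type:8s} {name}: curvature in x within [{low:.3f}, {high:.3f}]")
```

Curvature ranges that are entirely positive indicate convexity in $x$, and ranges that are entirely negative indicate concavity in $x$.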

## Widget <a class="anchor" id="widget"></a>

The following hidden cell implements a widget which shows the optimization dynamics
of Cooper when solving the problem in Eq. (1) and, as a bonus, the problem in Eq. (2).

The widget displays:
- Loss throughout training.
- Values of the Lagrange multiplier and constraint defect throughout training.
- Optimization path in $(x, y)$ space. The feasible set is highlighted in blue,
and contours of the loss function are drawn.
- Optimization path in $(f, g)$ space. The Pareto front formed between $f$ and $g$
is shown for reference.
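As an illustration of the last panel, the Pareto front can be traced by fixing $y = 1$, which minimizes the shared factor $1 + (y - 1)^2$, and sweeping $x$ over $[0, \pi/2]$. The sketch below lists a few $(f, g)$ pairs on the front for both problem types. `pareto_front` is a hypothetical helper written for this example; it mirrors the idea behind the widget's plot but is not the widget's code.

```python
import math


def pareto_front(problem_type="Convex", epsilon=0.7, n=5):
    """Return (f, g) pairs along y = 1 for x evenly spaced in [0, pi/2]."""
    points = []
    for i in range(n):
        x = (math.pi / 2) * i / (n - 1)
        if problem_type == "Convex":  # Eq. (1)
            f, g = 1 - math.sin(x), 1 - math.cos(x) - epsilon
        elif problem_type == "Concave":  # Eq. (2)
            f, g = math.sin(x), math.cos(x) - epsilon
        else:
            raise ValueError("Unknown problem type.")
        points.append((f, g))
    return points


for problem_type in ("Convex", "Concave"):
    print(problem_type)
    for f, g in pareto_front(problem_type):
        status = "feasible" if g <= 0 else "infeasible"
        print(f"  f={f:.3f}  g={g:+.3f}  ({status})")
```

In both cases the points lie on a quarter circle of radius 1. For Eq. (1) it is centred at $(1, 1 - \epsilon)$ and bows toward the origin; for Eq. (2) it is centred at $(0, -\epsilon)$ and bows away from it.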

Control items:
- Problem type: Convex, Eq. (1), or Concave, Eq. (2).
- Constraint level $\epsilon$.
- Number of iterations to train for.
- Initial values for $(x, y)$.
- Primal optimizer class (e.g. SGD) and its learning rate.
- Dual optimizer class and its learning rate.
- Whether to employ dual restarts.
- Whether to use extragradient (extrapolation) updates on the parameters and multiplier.
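To make the dynamics behind these controls concrete, here is a minimal plain-Python sketch of simultaneous gradient descent-ascent on the Lagrangian $\mathcal{L}(x, y, \lambda) = f(x, y) + \lambda \, g(x, y)$ of Eq. (1). It rests on simplifying assumptions: plain SGD for both the primal and dual updates, hand-derived gradients, and a dual restart modeled as resetting the multiplier to zero whenever the constraint is satisfied. It does not call Cooper, and the function name `descent_ascent` and its arguments are hypothetical.

```python
import math


def descent_ascent(x, y, epsilon=0.7, primal_lr=2e-2, dual_lr=0.5,
                   num_iters=200, dual_restarts=False):
    """Simultaneous gradient descent-ascent on L = f + lam * g for Eq. (1)."""
    lam = 0.0
    history = []
    for _ in range(num_iters):
        scale = 1 + (y - 1) ** 2
        f = (1 - math.sin(x)) * scale
        g = (1 - math.cos(x)) * scale - epsilon

        # Hand-derived gradients of the Lagrangian w.r.t. x and y.
        grad_x = (-math.cos(x) + lam * math.sin(x)) * scale
        grad_y = 2 * (y - 1) * ((1 - math.sin(x)) + lam * (1 - math.cos(x)))

        # Primal descent, then projection back onto the domain.
        new_x = min(max(x - primal_lr * grad_x, 0.0), math.pi / 2)
        new_y = max(y - primal_lr * grad_y, 0.0)

        # Dual ascent: dL/dlam = g. Project so the multiplier stays >= 0.
        new_lam = max(lam + dual_lr * g, 0.0)
        if dual_restarts and g < 0:
            # Assumed restart rule: drop the multiplier once feasible.
            new_lam = 0.0

        history.append({"loss": f, "defect": g, "multiplier": lam})
        x, y, lam = new_x, new_y, new_lam
    return (x, y, lam), history


(x, y, lam), history = descent_ascent(x=0.9, y=2.0)
print(f"final x={x:.3f}, y={y:.3f}, multiplier={lam:.3f}")
print(f"final loss={history[-1]['loss']:.3f}, defect={history[-1]['defect']:+.3f}")
```

The default arguments match the widget's initial slider values. Both players are updated from the same iterate, so the multiplier reacts to the defect measured before the primal step.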

In [123]:
class Toy2DWidget():
    def __init__(self):

        # --------------------------------------- Create some control elements
        problem_type_dropdown = widgets.Dropdown(
            options=["Convex", "Concave"],
            description="Problem type",
        )
        epsilon_slider = widgets.FloatSlider(
            min=-0.1, max=1.1, step=0.05, value=0.7, description="Const. level"
        )
        primal_lr_slider = widgets.FloatLogSlider(
            base=10,
            min=-4,
            max=0,
            step=0.1,
            value=2e-2,
            description="Primal LR",
            continuous_update=False,
        )
        primal_optim_dropdown = widgets.Dropdown(
            value="SGD",
            options=["SGD", "SGDM_0.9", "Adam"],
            description="Primal opt.",
        )
        dual_lr_slider = widgets.FloatLogSlider(
            base=10,
            min=-4,
            max=0,
            step=0.1,
            value=5e-1,
            description="Dual LR",
            continuous_update=False,
        )
        dual_optim_dropdown = widgets.Dropdown(
            value="SGD",
            options=["SGD", "SGDM_0.9", "Adam"],
            description="Dual opt.",
        )
        x_slider = widgets.FloatSlider(
            min=0,
            max=np.pi / 2,
            step=0.01,
            value=0.9,
            description="x init.",
            continuous_update=False,
        )
        y_slider = widgets.FloatSlider(
            min=0,
            max=3.0,
            step=0.01,
            value=2.,
            description="y init.",
            continuous_update=False,
        )
        iters_textbox = widgets.BoundedIntText(
            min=1, max=10000, value=200, description="Max Iters"
        )
        restarts_checkbox = widgets.Checkbox(value=False, description="Dual restarts")
        extrapolation_checkbox = widgets.Checkbox(
            value=False, description="Extrapolation"
        )

        # --------------------------------- Indicate what each option observes
        widget = interactive(
            self.update,
            x=x_slider,
            y=y_slider,
            num_iters=iters_textbox,
            epsilon=epsilon_slider,
            problem_type=problem_type_dropdown,
            primal_lr=primal_lr_slider,
            primal_optim=primal_optim_dropdown,
            dual_lr=dual_lr_slider,
            dual_optim=dual_optim_dropdown,
            dual_restarts=restarts_checkbox,
            extrapolation=extrapolation_checkbox,
        )
        controls_layout = Layout(
            display="flex",
            flex_flow="row wrap",
            border="solid 2px",
            align_items="center",
            width="950px",
        )
        controls = HBox(widget.children[:-1], layout=controls_layout)
        output = widget.children[-1]
        display(VBox([controls, output]))

        # ------------------------------ Initialize the CMP and its formulation
        self.cmp = Toy2DCMP(is_constrained=True)
        self.formulation = cooper.LagrangianFormulation(self.cmp)

        # Run the update a first time
        widget.update()

    def reset_problem(self, epsilon=None, problem_type="Convex"):
        """Reset the cmp and formulation for new training loops."""

        self.cmp.problem_type = problem_type

        # Reset the state of the CMP. Update epsilon if necessary.
        self.cmp.epsilon = epsilon
        self.cmp.state = None

        # Reset multipliers
        self.formulation.ineq_multipliers = None
        self.formulation.eq_multipliers = None

    def update(
        self,
        problem_type,
        epsilon,
        num_iters,
        primal_optim,
        dual_optim,
        x,
        primal_lr,
        dual_lr,
        y,
        extrapolation,
        dual_restarts,
    ):

        # Initialize the figure
        self.fig = plt.figure(figsize=(15, 5))
        grid_specs = GridSpec(2, 3, figure=self.fig)
        self.loss_iter_axis = self.fig.add_subplot(grid_specs[0, 0])
        self.defect_iter_axis = self.fig.add_subplot(grid_specs[1, 0])
        self.xy_axis = self.fig.add_subplot(grid_specs[:, 1])
        self.loss_defect_axis = self.fig.add_subplot(grid_specs[:, 2])

        # Reset the state of cmp and formulation. Indicate the new epsilon.
        self.reset_problem(epsilon=epsilon, problem_type=problem_type)

        # Plot the loss contours and cache the defect grid, which
        # self.plot_feasible_set uses to shade the feasible region.
        self.contour_params = self.loss_contours()

        # Plot the Pareto front.
        self.plot_pareto_front()

        # Shade the feasible set in (x, y) space and in (f, g) space, and
        # mark the zero-defect level.
        self.plot_feasible_set()

        # New initialization
        params = torch.nn.Parameter(torch.tensor([[x, y]]))

        # Construct a new optimizer
        self.constrained_optimizer = self.create_optimizer(
            params=params,
            primal_optim=primal_optim,
            dual_optim=dual_optim,
            primal_lr=primal_lr,
            dual_lr=dual_lr,
            dual_restarts=dual_restarts,
            extrapolation=extrapolation,
        )

        state_history = self.train(params=params, num_iters=num_iters)
        self.update_trajectory_plots(state_history)

    def create_optimizer(
        self,
        params,
        primal_optim,
        primal_lr,
        dual_optim,
        dual_lr,
        dual_restarts,
        extrapolation,
    ):

        # Check whether either optimizer uses momentum and add it to its kwargs
        primal_kwargs = {"lr": primal_lr}
        if primal_optim == "SGDM_0.9":
            primal_optim = "SGD"
            primal_kwargs["momentum"] = 0.9
        dual_kwargs = {"lr": dual_lr}
        if dual_optim == "SGDM_0.9":
            dual_optim = "SGD"
            dual_kwargs["momentum"] = 0.9

        # Indicate if we are using extrapolation
        if extrapolation:
            primal_optim = "Extra" + primal_optim
            dual_optim = "Extra" + dual_optim

        primal_optimizer = getattr(cooper.optim, primal_optim)(
            [params],
            **primal_kwargs,
        )
        dual_optimizer = cooper.optim.partial(
            getattr(cooper.optim, dual_optim),
            **dual_kwargs,
        )

        constrained_optimizer = cooper.ConstrainedOptimizer(
            formulation=self.formulation,
            primal_optimizer=primal_optimizer,
            dual_optimizer=dual_optimizer,
            dual_restarts=dual_restarts,
        )

        return constrained_optimizer

    def train(self, params, num_iters):
        """Run the training loop and return the logged metrics."""

        # Store CMPStates and parameter values throughout the optimization process
        state_history = tutorial_utils.StateLogger(
            save_metrics=["loss", "ineq_defect", "ineq_multipliers"]
        )

        for iter_num in range(num_iters):

            self.constrained_optimizer.zero_grad()
            lagrangian = self.formulation.composite_objective(self.cmp.closure, params)
            self.formulation.custom_backward(lagrangian)
            self.constrained_optimizer.step(self.cmp.closure, params)

            # Ensure parameters remain in the domain of the functions
            params[:, 0].data.clamp_(min=0, max=np.pi / 2)
            params[:, 1].data.clamp_(min=0)

            # Store optimization metrics at each step
            state_history.store_metrics(
                self.formulation,
                iter_num,
                partial_dict={"params": copy.deepcopy(params.data)},
            )

        return state_history

    def loss_contours(self):
        """Plot the loss contours."""
        # Initial contours for plot
        x_range = torch.tensor(np.linspace(0, np.pi / 2, 100))
        y_range = torch.tensor(np.linspace(0, 2.0, 100))
        grid_x, grid_y = torch.meshgrid(x_range, y_range, indexing="ij")

        grid_params = torch.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
        all_states = self.cmp.closure(grid_params)
        loss_grid = all_states.loss.reshape(len(x_range), len(y_range))

        # Plot the contours
        CS = self.xy_axis.contour(
            grid_x,
            grid_y,
            loss_grid,
            levels=[0.05, 0.125, 0.25, 0.5, 1, 1.5],
            alpha=1.0,
            colors="gray"
        )

        # Add styling
        self.xy_axis.clabel(CS, inline=1)

        defect_grid = all_states.ineq_defect.reshape(len(x_range), len(y_range))
        return (grid_x, grid_y, defect_grid)

    def plot_pareto_front(self):
        """Plot the Pareto front in the loss vs defect plane. This part is done
        once."""
        # y parametrizes the distance to the front. Regardless of epsilon, y=1
        # yields a non-dominated solution. x parametrizes the position on the
        # Pareto front.
        x_range = torch.tensor(np.linspace(0, np.pi / 2, 100))
        y_range = torch.tensor(100 * [1.])
        all_states = self.cmp.closure(torch.stack([x_range, y_range], axis=1))
        self.pareto_front = (all_states.loss, all_states.ineq_defect.squeeze())
        self.loss_defect_axis.plot(
            self.pareto_front[0], self.pareto_front[1], c="black", alpha=0.7
        )

        # Add styling
        self.loss_defect_axis.set_xlabel(r"Objective $f$")
        self.loss_defect_axis.set_ylabel(r"Constraint $g$")

    def update_trajectory_plots(self, state_history):

        blue = style_utils.COLOR_DICT["blue"]
        red = style_utils.COLOR_DICT["red"]
        green = style_utils.COLOR_DICT["green"]
        yellow = style_utils.COLOR_DICT["yellow"]

        all_metrics = state_history.unpack_stored_metrics()
        cmap_vals = np.linspace(0, 1, len(all_metrics["loss"]))
        cmap_name = "viridis"

        # --------------------------------- Trajectory in x-y plane
        params_hist = np.stack(all_metrics["params"]).squeeze().reshape(-1, 2)

        SC = self.xy_axis.scatter(
            params_hist[:, 0], params_hist[:, 1], c=cmap_vals, cmap=cmap_name, s=20, alpha=0.5, zorder=10
        )
        # Add marker signaling the final iterate
        self.xy_axis.scatter(
            *params_hist[-1, :],
            marker="*",
            s=150,
            zorder=100,
            c=yellow,
        )
        self.xy_axis.set_xlabel(r"Param. $x$")
        self.xy_axis.set_ylabel(r"Param. $y$")
        self.xy_axis.set_title(r"Parameter $(x, y)$ space")
        # Constrain domain
        self.xy_axis.set_xlim(0, np.pi / 2)
        self.xy_axis.set_ylim(0, 2.0)

        # -------------------------------- Trajectory in loss-defect plane
        defects = np.stack(all_metrics["ineq_defect"]).squeeze()
        self.loss_defect_axis.scatter(all_metrics["loss"], defects, alpha=0.5, s=20, c=cmap_vals, cmap=cmap_name)
        # Add marker signaling the final iterate
        self.loss_defect_axis.scatter(
            all_metrics["loss"][-1], defects[-1], marker="*", s=150, zorder=10, c=yellow
        )
        self.loss_defect_axis.set_title(r"Loss vs. constraint $(f, g)$ space")
        self.loss_defect_axis.set_xlim(- 0.1, 1.2)
        self.loss_defect_axis.set_ylim(-self.cmp.epsilon - 0.1, 1.2 - self.cmp.epsilon)
        self.loss_defect_axis.set_aspect('equal')

        # -------------------------------- Loss history
        self.loss_iter_axis.plot(
            all_metrics["iters"], all_metrics["loss"], c=blue, linewidth=2
        )
        self.loss_iter_axis.set_title("Loss")
        self.loss_iter_axis.set_xlabel("Iteration")

        # -------------------------------- Multiplier and defect history
        self.defect_iter_axis.plot(
            all_metrics["iters"],
            defects,
            c=red,
            linewidth=2,
            label="Defect",
            zorder=10,
        )
        self.defect_iter_axis.plot(
            all_metrics["iters"],
            np.stack(all_metrics["ineq_multipliers"]).squeeze(),
            c=green,
            linewidth=2,
            label="Multiplier",
        )
        self.defect_iter_axis.set_xlabel("Iteration")
        self.defect_iter_axis.legend(
            ncol=2, loc="upper right", bbox_to_anchor=(0.8, 1.25)
        )

        # -------------------------------- Colorbar
        last = len(all_metrics["loss"])
        cbar = self.fig.colorbar(SC, label="Iteration", ax=self.loss_defect_axis, ticks=np.linspace(0., 1., 6))
        # One label per tick; linspace also handles runs with fewer than 5 iterations
        cbar.ax.set_yticklabels(np.linspace(0, last, 6).astype(int))

        self.fig.tight_layout()
        # TODO: warning https://stackoverflow.com/questions/69999315/using-astropy-with-matplotlib-i-get-a-warning-to-call-gridfalse-first-due-to
        # self.fig.subplots_adjust(right=1.1)


    def plot_feasible_set(self):
        """Plot the feasible set."""
        # the values of g(x, y) have been computed in self.loss_contours for
        # the whole grid. The feasibility boundary changes based on the epsilon
        self.xy_axis.contourf(
            *self.contour_params,
            levels=[-10, 0],
            colors=style_utils.COLOR_DICT["blue"],
            alpha=0.1,
        )

        self.defect_iter_axis.axhline(0, c="gray", alpha=0.7, linestyle="--")

        # In the loss vs. defect plane, shade the feasible region (defect <= 0)
        # and draw a line at zero defect, i.e. where g equals epsilon
        y = torch.cat((torch.tensor([- self.cmp.epsilon]), self.pareto_front[1]))
        self.loss_defect_axis.fill_between(
            x=torch.cat((torch.tensor([1.2]), self.pareto_front[0])),
            y1=y,
            y2=0,
            where=y <= 0,
            step="mid",
            color=style_utils.COLOR_DICT["blue"],
            alpha=0.1,
        )
        self.loss_defect_axis.axhline(0, c="gray", alpha=0.7, linestyle="--")

widget = Toy2DWidget()

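The `Extrapolation` checkbox switches the widget to optimizers whose names carry the `Extra` prefix, as seen in `create_optimizer`. The general extragradient idea can be sketched in plain Python on the Lagrangian of Eq. (1): take a provisional look-ahead step, evaluate the gradients there, and apply them from the original iterate. The sketch assumes plain gradient steps with projection and hand-derived gradients; it does not reproduce Cooper's implementation, and the names `lagrangian_grads`, `project`, and `extragradient_step` are hypothetical.

```python
import math


def lagrangian_grads(x, y, lam, epsilon):
    """Gradients of L = f + lam * g for Eq. (1): (dL/dx, dL/dy, dL/dlam)."""
    scale = 1 + (y - 1) ** 2
    grad_x = (-math.cos(x) + lam * math.sin(x)) * scale
    grad_y = 2 * (y - 1) * ((1 - math.sin(x)) + lam * (1 - math.cos(x)))
    grad_lam = (1 - math.cos(x)) * scale - epsilon  # equals the defect g
    return grad_x, grad_y, grad_lam


def project(x, y, lam):
    """Keep (x, y) in the domain and the multiplier non-negative."""
    return min(max(x, 0.0), math.pi / 2), max(y, 0.0), max(lam, 0.0)


def extragradient_step(x, y, lam, epsilon, primal_lr, dual_lr):
    """One extragradient update: look ahead, then step from the original point."""
    # 1) Extrapolation: a provisional descent-ascent step.
    gx, gy, gl = lagrangian_grads(x, y, lam, epsilon)
    x_half, y_half, lam_half = project(
        x - primal_lr * gx, y - primal_lr * gy, lam + dual_lr * gl
    )
    # 2) Update: re-evaluate the gradients at the look-ahead point, but
    #    apply them starting from the original iterate.
    gx, gy, gl = lagrangian_grads(x_half, y_half, lam_half, epsilon)
    return project(x - primal_lr * gx, y - primal_lr * gy, lam + dual_lr * gl)


x, y, lam = 0.9, 2.0, 0.0
for _ in range(200):
    x, y, lam = extragradient_step(
        x, y, lam, epsilon=0.7, primal_lr=2e-2, dual_lr=0.5
    )
print(f"x={x:.3f}, y={y:.3f}, multiplier={lam:.3f}")
```

Each extragradient step costs two gradient evaluations instead of one, which is the price paid for using look-ahead information in the update.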

## References <a class="anchor" id="references"></a>

- Degrave, J. and Korshunova, I. Why machine learning algorithms are hard to tune and how to fix it. Engraved, [blog](https://www.engraved.blog/why-machine-learning-algorithms-are-hard-to-tune/), 2021a.
- Degrave, J. and Korshunova, I. How we can make machine learning algorithms tunable. Engraved, [blog](https://www.engraved.blog/how-we-can-make-machine-learning-algorithms-tunable/), 2021b.