In [1]:
import numpy as np
import copy

import ipywidgets as widgets
from IPython.display import display
from ipywidgets import interactive, HBox, VBox, Layout
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

%matplotlib inline

import torch

import cooper

import style_utils
import tutorial_utils

In [2]:
class Convex2dCMP(cooper.ConstrainedMinimizationProblem):
    """Toy 2D minimization problem with a single inequality constraint.

    Two variants are available, selected through ``problem_type``
    (matched case-insensitively):

    * ``"Convex"``:  f = (1 - sin(x)) * (1 + (y - 0.5)^2)
                     g = (1 - cos(x)) * (1 + (y - 0.5)^2) - epsilon
    * ``"Concave"``: f = sin(x) * (1 + (y - 0.5)^2)
                     g = cos(x) * (1 + (y - 0.5)^2) - epsilon

    In both cases the constraint is expressed in standard form ``g <= 0``.

    Args:
        is_constrained: Forwarded to the ``cooper`` base class constructor.
        epsilon: Constraint level subtracted when computing the defect ``g``.
        problem_type: ``"Convex"`` or ``"Concave"``.
    """

    def __init__(self, is_constrained=False, epsilon=1.0, problem_type="Convex"):
        self.problem_type = problem_type
        self.epsilon = epsilon
        super().__init__(is_constrained)

    def closure(self, params):
        """Evaluate the objective and the constraint defect at ``params``.

        Args:
            params: Tensor indexed as ``params[:, 0]`` (x) and ``params[:, 1]``
                (y), i.e. one (x, y) pair per row.

        Returns:
            A ``cooper.CMPState`` holding the loss ``f``, the inequality defect
            ``g`` and a copy of ``g`` under ``misc["g"]``.

        Raises:
            ValueError: If ``self.problem_type`` is neither "Convex" nor
                "Concave" (ignoring case).
        """

        x, y = params[:, 0], params[:, 1]

        # Term shared by the objective and the constraint in both variants
        y_term = 1 + (y - 0.5) ** 2

        # `problem_type` may be reassigned after construction, so it is checked
        # here. Matching is case-insensitive and unknown values are rejected so
        # that a typo cannot silently select the wrong problem.
        kind = str(self.problem_type).lower()
        if kind == "convex":
            f = (1 - torch.sin(x)) * y_term
            # In standard form (defect <= 0)
            g = (1 - torch.cos(x)) * y_term - self.epsilon
        elif kind == "concave":
            f = torch.sin(x) * y_term
            # In standard form (defect <= 0)
            g = torch.cos(x) * y_term - self.epsilon
        else:
            raise ValueError(
                f"Unknown problem_type {self.problem_type!r}; "
                "expected 'Convex' or 'Concave'."
            )

        # Store the values in a CMPState as attributes
        return cooper.CMPState(loss=f, ineq_defect=g, misc={"g": g})

In [18]:
class Toy2DWidget(VBox):
    """Interactive explorer for constrained optimization on ``Convex2dCMP``.

    The constructor builds a set of ipywidgets controls (problem type,
    constraint level, learning rates, optimizers, initial point, number of
    iterations, dual restarts and extrapolation) and binds them to
    ``self.update`` through ``ipywidgets.interactive``. Every call to
    ``update`` creates a new figure, re-solves the problem from the chosen
    initial point and draws:

    * the loss and the defect/multiplier as functions of the iteration,
    * the trajectory in the (x, y) plane over the loss contours and the
      feasible set,
    * the trajectory in the (objective, defect) plane.

    NOTE(review): the controls are shown through an explicit ``display`` call
    in ``__init__``; the children of this ``VBox`` instance itself are never
    set.
    """

    def __init__(self):

        super().__init__()

        # --------------------------------------- Create some control elements
        problem_type_dropdown = widgets.Dropdown(
            options=["Convex", "Concave"],
            description="Problem type",
        )
        epsilon_slider = widgets.FloatSlider(
            min=-0.1, max=1.1, step=0.05, value=0.7, description="Const. level"
        )
        primal_lr_slider = widgets.FloatLogSlider(
            base=10,
            min=-4,
            max=0,
            step=0.1,
            value=2e-2,
            description="Primal LR",
            continuous_update=False,
        )
        primal_optim_dropdown = widgets.Dropdown(
            value="SGD",
            options=["SGD", "SGDM_0.9", "Adam"],
            description="Primal opt.",
        )
        dual_lr_slider = widgets.FloatLogSlider(
            base=10,
            min=-4,
            max=0,
            step=0.1,
            value=5e-1,
            description="Dual LR",
            continuous_update=False,
        )
        dual_optim_dropdown = widgets.Dropdown(
            value="SGD",
            options=["SGD", "SGDM_0.9", "Adam"],
            description="Dual opt.",
        )
        x_slider = widgets.FloatSlider(
            min=0,
            max=np.pi / 2,
            step=0.01,
            value=1.0,
            description="Initial x",
            continuous_update=False,
        )
        y_slider = widgets.FloatSlider(
            min=0,
            max=2.0,
            step=0.01,
            value=1.0,
            description="Initial y",
            continuous_update=False,
        )
        iters_textbox = widgets.BoundedIntText(
            min=1, max=10000, value=200, description="Iters"
        )
        restarts_checkbox = widgets.Checkbox(value=False, description="Dual restarts")
        extrapolation_checkbox = widgets.Checkbox(
            value=False, description="Extrapolation"
        )

        # Bind each control to the matching keyword argument of self.update
        widget = interactive(
            self.update,
            x=x_slider,
            y=y_slider,
            num_iters=iters_textbox,
            epsilon=epsilon_slider,
            problem_type=problem_type_dropdown,
            primal_lr=primal_lr_slider,
            primal_optim=primal_optim_dropdown,
            dual_lr=dual_lr_slider,
            dual_optim=dual_optim_dropdown,
            dual_restarts=restarts_checkbox,
            extrapolation=extrapolation_checkbox,
        )
        controls_layout = Layout(
            display="flex",
            flex_flow="row wrap-reverse",
            border="solid 2px",
            align_items="stretch",
            width="1500px",
        )
        # The last child of the interactive widget is its output area; all the
        # previous ones are the controls.
        controls = HBox(widget.children[:-1], layout=controls_layout)
        output = widget.children[-1]
        display(VBox([controls, output]))

        # Initialize the CMP and its formulation
        self.cmp = Convex2dCMP(is_constrained=True)
        self.formulation = cooper.LagrangianFormulation(self.cmp)

        # Running the update a first time is left disabled:
        # widget.update()

    def reset_problem(self, epsilon=None, problem_type="convex"):
        """Reset the cmp and formulation for new training loops.

        Args:
            epsilon: New constraint level. When ``None`` the current level of
                the CMP is kept.
            problem_type: Problem variant, "Convex" or "Concave" (any casing).
        """

        # The CMP in this file compares against the capitalized names, so the
        # casing is normalized here; otherwise the lowercase default would
        # silently select the concave branch.
        self.cmp.problem_type = str(problem_type).capitalize()

        # Reset the state of the CMP. Update epsilon only if one was provided,
        # so that the default call does not replace it with None.
        if epsilon is not None:
            self.cmp.epsilon = epsilon
        self.cmp.state = None

        # Reset multipliers
        self.formulation.ineq_multipliers = None
        self.formulation.eq_multipliers = None

    def update(
        self,
        y,
        extrapolation,
        dual_restarts,
        epsilon,
        x,
        primal_lr,
        dual_lr,
        problem_type,
        num_iters,
        primal_optim,
        dual_optim,
    ):
        """Re-solve the problem with the given settings and redraw all plots.

        This is the callback bound to the controls; arguments are passed by
        keyword by ``ipywidgets.interactive``.

        Args:
            y: Initial value of the second coordinate.
            extrapolation: Whether to use the "Extra" variant of the optimizers.
            dual_restarts: Forwarded to ``cooper.ConstrainedOptimizer``.
            epsilon: Constraint level of the problem.
            x: Initial value of the first coordinate.
            primal_lr: Learning rate of the primal optimizer.
            dual_lr: Learning rate of the dual optimizer.
            problem_type: "Convex" or "Concave".
            num_iters: Number of optimization steps to run.
            primal_optim: One of "SGD", "SGDM_0.9" or "Adam".
            dual_optim: One of "SGD", "SGDM_0.9" or "Adam".
        """

        # Initialize the figure: loss and defect histories stacked on the left,
        # (x, y) plane in the middle and (objective, defect) plane on the right.
        self.fig = plt.figure(figsize=(16, 5), constrained_layout=True)
        grid_specs = GridSpec(2, 3, figure=self.fig)
        self.loss_iter_axis = self.fig.add_subplot(grid_specs[0, 0])
        self.defect_iter_axis = self.fig.add_subplot(grid_specs[1, 0])
        self.xy_axis = self.fig.add_subplot(grid_specs[:, 1])
        self.loss_defect_axis = self.fig.add_subplot(grid_specs[:, 2])

        # Reset the state of cmp and formulation. Indicate the new epsilon.
        self.reset_problem(epsilon=epsilon, problem_type=problem_type)

        # Plot the loss contours on the fresh axes and keep the grid and the
        # defect values, which are needed to draw the feasible set below.
        self.contour_params = self.loss_contours()

        # Plot the pareto front.
        self.plot_pareto_front()

        # Draw the filled contour indicating the feasible set in (x, y) space
        # and the zero-defect hline in (f, g) space
        self.plot_feasible_set()

        # New initialization
        params = torch.nn.Parameter(torch.tensor([[x, y]]))

        # Construct a new optimizer
        self.constrained_optimizer = self.create_optimizer(
            params=params,
            primal_optim=primal_optim,
            dual_optim=dual_optim,
            primal_lr=primal_lr,
            dual_lr=dual_lr,
            dual_restarts=dual_restarts,
            extrapolation=extrapolation,
        )

        state_history = self.train(params=params, num_iters=num_iters)
        self.update_trajectory_plots(state_history)

    def create_optimizer(
        self,
        params,
        primal_optim,
        primal_lr,
        dual_optim,
        dual_lr,
        dual_restarts,
        extrapolation,
    ):
        """Build the constrained optimizer for the current formulation.

        Args:
            params: Parameter tensor handed to the primal optimizer.
            primal_optim: "SGD", "SGDM_0.9" (SGD with momentum 0.9) or "Adam".
            primal_lr: Learning rate of the primal optimizer.
            dual_optim: Same choices as ``primal_optim``, for the dual optimizer.
            dual_lr: Learning rate of the dual optimizer.
            dual_restarts: Forwarded to ``cooper.ConstrainedOptimizer``.
            extrapolation: If True, the optimizer classes are looked up in
                ``cooper.optim`` with an "Extra" prefix (e.g. "ExtraSGD").

        Returns:
            The ``cooper.ConstrainedOptimizer`` wrapping both optimizers.
        """

        # "SGDM_0.9" is not an optimizer name by itself: it stands for SGD with
        # a momentum of 0.9, passed through the keyword arguments.
        primal_kwargs = {"lr": primal_lr}
        if primal_optim == "SGDM_0.9":
            primal_optim = "SGD"
            primal_kwargs["momentum"] = 0.9
        dual_kwargs = {"lr": dual_lr}
        if dual_optim == "SGDM_0.9":
            dual_optim = "SGD"
            dual_kwargs["momentum"] = 0.9

        # Indicate if we are using extrapolation
        if extrapolation:
            primal_optim = "Extra" + primal_optim
            dual_optim = "Extra" + dual_optim

        primal_optimizer = getattr(cooper.optim, primal_optim)(
            [params],
            **primal_kwargs,
        )
        # The dual optimizer is passed as a partially-applied constructor
        # rather than as an instance.
        dual_optimizer = cooper.optim.partial(
            getattr(cooper.optim, dual_optim),
            **dual_kwargs,
        )

        constrained_optimizer = cooper.ConstrainedOptimizer(
            formulation=self.formulation,
            primal_optimizer=primal_optimizer,
            dual_optimizer=dual_optimizer,
            dual_restarts=dual_restarts,
        )

        return constrained_optimizer

    def train(self, params, num_iters):
        """Run ``num_iters`` optimization steps starting from ``params``.

        Args:
            params: Parameter tensor, updated in place.
            num_iters: Number of steps to perform.

        Returns:
            The ``tutorial_utils.StateLogger`` holding the metrics and the
            parameter values recorded after every step.
        """

        # Store CMPStates and parameter values throughout the optimization process
        state_history = tutorial_utils.StateLogger(
            save_metrics=["loss", "ineq_defect", "ineq_multipliers"]
        )

        for iter_num in range(num_iters):

            self.constrained_optimizer.zero_grad()
            lagrangian = self.formulation.composite_objective(self.cmp.closure, params)
            self.formulation.custom_backward(lagrangian)
            self.constrained_optimizer.step(self.cmp.closure, params)

            # Ensure parameters remain in the domain of the functions
            params[:, 0].data.clamp_(min=0, max=np.pi / 2)
            params[:, 1].data.clamp_(min=0)

            # Store optimization metrics at each step. The parameters are
            # copied since they keep being modified in place.
            state_history.store_metrics(
                self.formulation,
                iter_num,
                partial_dict={"params": copy.deepcopy(params.data)},
            )

        return state_history

    def loss_contours(self):
        """Plot the loss contours in the (x, y) plane.

        Returns:
            Tuple ``(grid_x, grid_y, g_grid)`` with the evaluation grid and the
            constraint defect on that grid, for the current epsilon.
        """
        # Evaluation grid covering the plotted domain
        x_range = torch.tensor(np.linspace(0, np.pi / 2, 100))
        y_range = torch.tensor(np.linspace(0, 2.0, 100))
        grid_x, grid_y = torch.meshgrid(x_range, y_range, indexing="ij")

        # Evaluate the CMP on all grid points at once (one point per row)
        grid_params = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)
        all_states = self.cmp.closure(grid_params)
        grid_shape = (len(x_range), len(y_range))
        loss_grid = all_states.loss.reshape(grid_shape)

        # Plot the contours
        CS = self.xy_axis.contour(
            grid_x,
            grid_y,
            loss_grid,
            levels=[0.05, 0.125, 0.25, 0.5, 1, 1.5, 2.0, 3.0],
            alpha=1.0,
            cmap="summer",
        )

        # Add styling
        self.xy_axis.clabel(CS, inline=1)

        g_grid = all_states.misc["g"].reshape(grid_shape)
        return (grid_x, grid_y, g_grid)

    def plot_pareto_front(self):
        """Plot the pareto front in the loss vs defect plane."""
        # Both the loss and the defect scale with (1 + (y - 0.5)^2), which is
        # smallest at y=0.5, so the front is traced at y=0.5 while x
        # parametrizes the position along it.
        x_range = torch.tensor(np.linspace(0, np.pi / 2, 100))
        # Same length and dtype as x_range, instead of a hard-coded float32 list
        y_range = torch.full_like(x_range, 0.5)
        all_states = self.cmp.closure(torch.stack([x_range, y_range], dim=1))
        self.loss_defect_axis.plot(
            all_states.loss, all_states.misc["g"].squeeze(), c="gray", alpha=0.7
        )

        # Add styling
        self.loss_defect_axis.set_xlabel("Objective")
        self.loss_defect_axis.set_ylabel("Defect")

    def update_trajectory_plots(self, state_history):
        """Draw the optimization trajectory on all four axes.

        Args:
            state_history: Logger returned by ``self.train``.
        """

        blue = style_utils.COLOR_DICT["blue"]
        red = style_utils.COLOR_DICT["red"]
        green = style_utils.COLOR_DICT["green"]
        yellow = style_utils.COLOR_DICT["yellow"]

        all_metrics = state_history.unpack_stored_metrics()

        # Trajectory in x-y plane
        params_hist = np.stack(all_metrics["params"]).reshape(-1, 2)

        # The problem has a single inequality constraint, so the per-iteration
        # defects and multipliers are flattened to one value per iteration.
        # reshape(-1) is used rather than squeeze() so that a run with a single
        # iteration still yields a 1-D array that can be indexed with [-1].
        g = np.stack(all_metrics["ineq_defect"]).reshape(-1)
        multipliers = np.stack(all_metrics["ineq_multipliers"]).reshape(-1)

        self.xy_axis.scatter(
            params_hist[:, 0], params_hist[:, 1], c=blue, s=10, alpha=0.7, zorder=10
        )
        # Add marker signaling the final iterate
        self.xy_axis.scatter(
            *params_hist[-1, :],
            marker="*",
            s=150,
            zorder=100,
            c=yellow,
        )
        self.xy_axis.set_xlabel("x")
        self.xy_axis.set_ylabel("y")
        self.xy_axis.set_title("Loss contours")
        # Constrain domain
        self.xy_axis.set_xlim(0, np.pi / 2)
        self.xy_axis.set_ylim(0, 2.0)

        # Trajectory in loss-defect plane
        self.loss_defect_axis.scatter(all_metrics["loss"], g, s=2, c=blue)
        # Add marker signaling the final iterate
        self.loss_defect_axis.scatter(
            all_metrics["loss"][-1], g[-1], marker="*", s=150, zorder=10, c=yellow
        )

        # Loss history
        self.loss_iter_axis.plot(
            all_metrics["iters"], all_metrics["loss"], c=blue, linewidth=2
        )
        self.loss_iter_axis.set_title("Loss")
        self.loss_iter_axis.set_xlabel("Iteration")

        # Multiplier and defect history
        self.defect_iter_axis.plot(
            all_metrics["iters"],
            g,
            c=red,
            linewidth=2,
            label="Defect",
            zorder=10,
        )
        self.defect_iter_axis.plot(
            all_metrics["iters"],
            multipliers,
            c=green,
            linewidth=2,
            label="Multiplier",
        )
        self.defect_iter_axis.set_xlabel("Iteration")
        self.defect_iter_axis.legend(
            ncol=2, loc="upper right", bbox_to_anchor=(0.8, 1.25)
        )

        # Reference line at zero
        self.defect_iter_axis.axhline(0, c="gray", alpha=0.7, linestyle="--")

    def plot_feasible_set(self):
        """Shade the feasible set in the (x, y) plane and mark zero defect."""
        # The values of g(x, y) have been computed in self.loss_contours for
        # the whole grid, using the current epsilon. Points whose defect lies
        # between the two levels (-10 and 0) are shaded as feasible.
        self.xy_axis.contourf(
            *self.contour_params,
            levels=[-10, 0],
            colors="blue",
            alpha=0.1,
        )

        # In loss vs defect plane, a line is drawn at zero defect: epsilon is
        # already subtracted in the defect, so 0 is the feasibility boundary.
        self.loss_defect_axis.axhline(0, c="gray", alpha=0.7, linestyle="--")

In [20]:
# Instantiate the widget; its constructor displays the controls and output area.
w = Toy2DWidget()

VBox(children=(HBox(children=(FloatSlider(value=1.0, continuous_update=False, description='Initial y', max=2.0…