In [4]:
%pip install benchmarx --quiet


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [11]:
from benchmarx import Benchmark, QuadraticProblem, CustomOptimizer, Plotter
from benchmarx.src.metrics import CustomMetric
from benchmarx.src.custom_optimizer import State

import jax
import jax.numpy as jnp
import random

In [8]:
class MirrorDescent(CustomOptimizer):
    """
    Mirror Descent algorithm on the standard simplex.

    Every step rescales the current point coordinate-wise by
    ``exp(-stepsize * (A @ x))`` and renormalizes it so the entries sum to
    one, where ``A`` is read from ``problem.A``.

    Parameters
    ----------
    x_init : initial point; forwarded to ``CustomOptimizer.__init__``.
    stepsize : step size stored in the optimizer state and used by ``update``.
    problem : object exposing the matrix ``A`` used by ``update``.
    tol, maxiter : stored on the instance and forwarded in ``params``; how
        the base class consumes them is not visible here.
    label : display label forwarded to the base class.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label='MD'):
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Return the initial state: iteration counter 1 and the configured step size."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def update(self, sol, state: State) -> tuple([jnp.array, State]):
        """
        Perform one multiplicative update and renormalize.

        Returns the new point together with ``state`` (its ``iter_num`` is
        incremented in place).
        """
        Ax = self.problem.A @ sol
        exponent = -state.stepsize * Ax
        # Subtracting the largest exponent multiplies every weight by the same
        # positive constant, which cancels in the normalization below, but it
        # keeps exp() from overflowing or underflowing all weights to zero.
        weights = sol * jnp.exp(exponent - jnp.max(exponent))
        sol = weights / jnp.sum(weights)
        state.iter_num += 1
        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never signal early stopping from this optimizer."""
        return False

In [9]:
class CSGD_proj(CustomOptimizer):
    """
    Coordinate SGD on the standard simplex.

    Every step picks one coordinate uniformly at random, moves along that
    single coordinate of ``A @ x`` (``A`` is read from ``problem.A``) and
    projects the result back onto the simplex with ``proj``.

    The coordinate is drawn from Python's global ``random`` module, so call
    ``random.seed(...)`` before running if reproducible trajectories are
    needed.

    Parameters
    ----------
    x_init : initial point; forwarded to ``CustomOptimizer.__init__``.
    stepsize : step size stored in the optimizer state and used by ``update``.
    problem : object exposing the matrix ``A`` used by ``update``.
    tol, maxiter : stored on the instance and forwarded in ``params``; how
        the base class consumes them is not visible here.
    label : display label forwarded to the base class.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label='GD_proj'):
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Return the initial state: iteration counter 1 and the configured step size."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def proj(self, x):
        """
        Euclidean projection onto the standard simplex.

        Sorts the entries in decreasing order, finds the last position ``rho``
        (0-based) for which ``u[rho] + (1 - sum(u[:rho + 1])) / (rho + 1) > 0``
        and shifts every entry of ``x`` by that quantity's second term,
        clipping negatives to zero. An empty input is returned unchanged.
        """
        x = jnp.asarray(x)
        n = x.shape[0]
        if n == 0:
            return x

        x_desc = jnp.sort(x)[::-1]
        prefix_sums = jnp.cumsum(x_desc)
        # Candidate shift for every prefix length 1..n.
        shifts = (1 - prefix_sums) / jnp.arange(1, n + 1)
        feasible = x_desc + shifts > 0
        # Last index satisfying the condition; falls back to 0 when none does.
        rho = jnp.max(jnp.where(feasible, jnp.arange(n), 0))
        return jnp.maximum(x + shifts[rho], 0)

    def update(self, sol, state: State) -> tuple([jnp.array, State]):
        """
        Perform one random-coordinate step followed by a simplex projection.

        Returns the new point together with ``state`` (its ``iter_num`` is
        incremented in place).
        """
        Ax = self.problem.A @ sol
        n = Ax.shape[0]
        # random.randint is inclusive on both ends, hence n - 1.
        ind = random.randint(0, n - 1)
        # Sparse direction: only the sampled coordinate of A @ x is kept.
        g = jnp.zeros(n).at[ind].set(Ax[ind])
        sol = self.proj(sol - state.stepsize * g)
        state.iter_num += 1

        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never signal early stopping from this optimizer."""
        return False

In [10]:
# Build a quadratic test problem of dimension d with a zero linear term.
# mu is the strong-convexity constant (passed as mineig) and
# L is the Lipschitz constant of the gradient (passed as maxeig).

d = 10
mu = 1
L = 1000
problem = QuadraticProblem(n=d, b=jnp.zeros(d), mineig=mu, maxeig=L, info="QP")

In [12]:
# Fixed PRNG key so the starting point is the same on every run.
key = jax.random.PRNGKey(110520)
# Draw positive random weights and normalize them so the starting point sums
# to one, i.e. lies on the standard simplex both solvers operate on.
x_raw = jax.random.uniform(key, minval=0, maxval=1, shape=(d,))
x_init = x_raw / jnp.sum(x_raw)
# Iteration budget handed to each solver as maxiter.
nit = 200

In [13]:
# Mirror Descent instance: step size 1/L, tolerance 0, maxiter=nit.
md_solver = MirrorDescent(
    x_init=x_init, stepsize=1 / L, problem=problem,
    tol=0, maxiter=nit, label='MD',
)

In [14]:
# Projected coordinate SGD instance: step size 1/L, tolerance 0, maxiter=nit.
csgd_solver = CSGD_proj(
    x_init=x_init, stepsize=1 / L, problem=problem,
    tol=0, maxiter=nit, label='CSGD_proj',
)

In [15]:
# Custom metric "main_gap": x^T A x - min_i (A x)_i, evaluated with the
# matrix A of the problem defined above.
gap = CustomMetric(
    label="main_gap",
    func=lambda x: x.T @ problem.A @ x - jnp.min(problem.A @ x),
)

In [16]:
# Benchmark both solvers on the same problem with runs=3
# (presumably the number of repetitions per method; confirm in benchmarx).
# Tracked metrics: the built-in ones named by string plus the custom gap.
benchmark = Benchmark(
    problem=problem,
    runs=3,
    methods=[
        {"MirrorDescent": md_solver},
        {"CSGD_proj": csgd_solver},
    ],
    metrics=["nit", "history_x", "history_f", "history_df", gap],
)

In [17]:
# Execute the configured benchmark; the returned object is saved to disk in the next cell.
result = benchmark.run()

In [18]:
# Write the benchmark result to disk, then point a Plotter at the same file.
result.save('custom_method_data.json')

plotter = Plotter(data_path="custom_method_data.json")

# Draw the requested metrics with plotly, including the custom gap metric.
plotter.plot_plotly(
    metrics=[
        "Solution norm",
        "Distance to the optimum",
        "Primal gap",
        "Gradient norm",
        "Function value",
        gap,
    ]
)