In [None]:
%pip install benchmarx==0.0.11 --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/75.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.7/75.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m166.6/166.6 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m39.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.0/190.0 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.8/224.8 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
from benchmarx import Benchmark
from benchmarx.quadratic_problem import QuadraticProblem
from benchmarx.custom_optimizer import CustomOptimizer, State
from benchmarx.metrics import CustomMetric

import jax
import jax.numpy as jnp
import random

In [None]:
class MirrorDescent(CustomOptimizer):
    """
    Mirror Descent (multiplicative / exponentiated update) on the standard simplex.

    Each step rescales every coordinate of the iterate by
    ``exp(-stepsize * (A @ x)_i)`` and renormalises so the coordinates sum to one.
    ``A @ x`` is used as the descent direction; presumably this is the gradient
    of the quadratic objective when its linear term is zero -- confirm against
    the problem definition if a non-zero linear term is used.

    Parameters
    ----------
    x_init : array
        Starting point; forwarded to the base class.
    stepsize : float
        Constant step size stored in the optimizer state.
    problem : object
        Must expose a matrix attribute ``A`` compatible with ``A @ x``.
    tol : float, optional
        Stored and forwarded in ``params``; not used by ``stop_criterion`` here.
    maxiter : int, optional
        Stored and forwarded in ``params``.
    label : str, optional
        Label forwarded to the base class.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label = 'MD'):
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Create the initial state: iteration counter at 1 and the constant step size."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def update(self, sol, state: State) -> "tuple[jnp.ndarray, State]":
        """
        Perform one multiplicative step and renormalise onto the simplex.

        Returns the new iterate together with the same ``state`` object, whose
        ``iter_num`` has been incremented.
        """
        direction = self.problem.A @ sol
        exponent = -state.stepsize * direction
        # Subtracting the max is cancelled by the normalisation below, but it
        # keeps exp() from overflowing when stepsize * direction is large.
        weights = sol * jnp.exp(exponent - jnp.max(exponent))
        sol = weights / jnp.sum(weights)
        state.iter_num += 1
        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never request an early stop; this method always returns False."""
        return False

In [None]:
class CSGD_proj(CustomOptimizer):
    """
    Coordinate SGD with Euclidean projection onto the standard simplex.

    Every step picks one coordinate uniformly at random, moves along that
    single coordinate of ``A @ x`` and projects the result back onto the
    simplex.

    The coordinate is drawn with Python's global ``random`` module, so call
    ``random.seed(...)`` before running if reproducible trajectories are needed.

    Parameters
    ----------
    x_init : array
        Starting point; forwarded to the base class.
    stepsize : float
        Constant step size stored in the optimizer state.
    problem : object
        Must expose a matrix attribute ``A`` compatible with ``A @ x``.
    tol : float, optional
        Stored and forwarded in ``params``; not used by ``stop_criterion`` here.
    maxiter : int, optional
        Stored and forwarded in ``params``.
    label : str, optional
        Label forwarded to the base class.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label = 'GD_proj'):
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Create the initial state: iteration counter at 1 and the constant step size."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def proj(self, x):
        """
        Euclidean projection of a 1-D vector onto the standard simplex.

        Sort-based algorithm: with ``u`` the entries sorted in decreasing order,
        find the largest ``k`` such that ``u_k + (1 - sum(u_1..u_k)) / k > 0``,
        set ``shift = (1 - sum(u_1..u_k)) / k`` and return ``max(x + shift, 0)``.
        """
        x = jnp.asarray(x)
        n = x.shape[0]
        u = jnp.sort(x)[::-1]                 # entries in decreasing order
        prefix = jnp.cumsum(u)                # prefix[k-1] = u_1 + ... + u_k
        k = jnp.arange(1, n + 1)
        active = u + (1 - prefix) / k > 0
        # Largest k satisfying the condition; k = 1 is the fallback, matching
        # the initial value used by the sequential scan.
        rho = jnp.max(jnp.where(active, k, 1))
        shift = (1 - prefix[rho - 1]) / rho
        return jnp.maximum(x + shift, 0)

    def update(self, sol, state: State) -> "tuple[jnp.ndarray, State]":
        """
        Perform one projected coordinate step.

        Returns the new iterate together with the same ``state`` object, whose
        ``iter_num`` has been incremented.
        """
        Ax = self.problem.A @ sol
        # randint is inclusive on both ends, hence the "- 1".
        ind = random.randint(a=0, b=Ax.shape[0] - 1)
        # Direction that is zero everywhere except the sampled coordinate.
        g = jnp.zeros_like(Ax).at[ind].set(Ax[ind])
        sol = self.proj(sol - state.stepsize * g)
        state.iter_num += 1

        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never request an early stop; this method always returns False."""
        return False

In [None]:
# Let's generate a quadratic problem with
#   - gradient Lipschitz constant  L  = 1000
#   - strong-convexity constant    mu = 1
# in dimension d = 10.
d = 10
mu = 1
L = 1000

problem = QuadraticProblem(n=d, b=jnp.zeros(d), mineig=mu, maxeig=L, info="QP")



In [None]:
# One seed for every source of randomness used in this notebook.
SEED = 110520
# CSGD_proj samples its coordinate with Python's `random` module; seeding it
# here makes the stochastic runs reproducible after a kernel restart.
random.seed(SEED)
key = jax.random.PRNGKey(SEED)

# Random starting point on the standard simplex: draw non-negative entries and
# normalise them to sum to one (dividing by d alone does not give sum == 1).
x_init = jax.random.uniform(key, minval=0, maxval=1, shape=(d,))
x_init = x_init / jnp.sum(x_init)

# Iteration budget passed to both solvers as `maxiter`.
nit = 200

In [None]:
# Mirror-descent solver: constant step 1/L, iteration budget `nit`.
md_config = dict(
    x_init=x_init,
    problem=problem,
    stepsize=1/L,
    maxiter=nit,
    tol=0,
    label='MD',
)
md_solver = MirrorDescent(**md_config)

In [None]:
# Projected coordinate-SGD solver: constant step 1/L, iteration budget `nit`.
csgd_config = dict(
    x_init=x_init,
    problem=problem,
    stepsize=1/L,
    maxiter=nit,
    tol=0,
    label='CSGD_proj',
)
csgd_solver = CSGD_proj(**csgd_config)

In [None]:
# Custom metric: x^T A x - min_i (A x)_i
def _main_gap(x):
    """Return ``x^T A x`` minus the smallest coordinate of ``A x``."""
    quadratic_term = x.T @ problem.A @ x
    smallest_coordinate = jnp.min(problem.A @ x)
    return quadratic_term - smallest_coordinate


gap = CustomMetric(func=_main_gap, label="main_gap")

In [None]:
# Both solvers are benchmarked on the same problem with the same metrics.
solvers = [
    {"MirrorDescent": md_solver},
    {"CSGD_proj": csgd_solver},
]
tracked_metrics = ["nit", "x", "f", "grad", gap]

benchmark = Benchmark(
    problem=problem,
    methods=solvers,
    metrics=tracked_metrics,
    runs=3,
)

In [None]:
# Run the benchmark; the returned object is used for plotting in the next cell.
result = benchmark.run()

In [None]:
# Plot the custom gap next to the built-in metrics.
plotted_metrics = [gap, 'f', 'x_norm', 'f_gap', 'x_gap', 'grad_norm']
result.plot(metrics=plotted_metrics)