In [None]:
!pip install benchmarx==0.0.11 --quiet

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 75.7/75.7 kB 3.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 166.6/166.6 kB 10.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 42.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 190.0/190.0 kB 16.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.8/224.8 kB 20.1 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.7/62.7 kB 5.4 MB/s eta 0:00:00
  Building wheel for pathtools (setup.py) ... done


In [None]:
# Imports
from benchmarx import Benchmark
from benchmarx.custom_optimizer import CustomOptimizer, State
from benchmarx.quadratic_problem import QuadraticProblem
import jax.numpy as jnp
from jax import random, grad

In [None]:
class MyGradientDescent(CustomOptimizer):
    """Gradient descent with a diminishing step size, pluggable into benchmarx.

    The first step uses the user-supplied ``stepsize``. After every update the
    step size is replaced by ``stepsize_schedule(iter_num)``, which defaults to
    ``1 / (iter_num + 1)``.

    NOTE(review): with the default schedule the second step is 1/3 regardless
    of ``stepsize`` (e.g. a jump from 1e-2 to ~0.33). Pass ``stepsize_schedule``
    (for instance ``lambda k: stepsize / k``) if the decay should scale with
    the initial step instead.

    Args:
        x_init: Starting point, forwarded to the base class.
        stepsize: Step size used for the first iteration only.
        problem: Object exposing a scalar objective ``problem.f``; its gradient
            is obtained with ``jax.grad``.
        tol: Stop once the *squared* gradient norm drops below this value.
        maxiter: Stored and forwarded in ``params``; not read by this class
            (presumably enforced by the base class / benchmark runner).
        label: Label passed to the base class.
        stepsize_schedule: Optional callable ``iter_num -> stepsize`` applied
            after each update. ``None`` keeps the ``1 / (iter_num + 1)`` rule.
    """

    def __init__(self, x_init, stepsize, problem, tol=1e-6, maxiter=400,
                 label='MyGD', stepsize_schedule=None):
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.tol = tol
        # Kept out of `params` so that dict only holds plain values.
        self._stepsize_schedule = stepsize_schedule
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Return the initial state: iteration counter 1 and the user step size."""
        return State(iter_num=1, stepsize=self.stepsize)

    def _grad(self, x):
        """Gradient of ``self.problem.f`` at ``x`` via ``jax.grad``."""
        return grad(self.problem.f)(x)

    def _next_stepsize(self, iter_num):
        """Step size to use once the counter has reached ``iter_num``."""
        if self._stepsize_schedule is None:
            return 1 / (iter_num + 1)
        return self._stepsize_schedule(iter_num)

    def update(self, sol, state: State) -> "tuple[jnp.ndarray, State]":
        """Take one gradient step, then advance the counter and step size.

        ``state`` is mutated in place and returned together with the new point.
        """
        # Out-of-place subtraction: never modifies the caller's array, even if
        # a mutable (e.g. NumPy) array is passed in.
        sol = sol - state.stepsize * self._grad(sol)
        state.iter_num += 1
        state.stepsize = self._next_stepsize(state.iter_num)
        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Return whether ``||grad f(sol)||**2 < tol`` (squared norm, not the norm).

        The result is a JAX boolean scalar produced by the comparison.
        """
        return jnp.linalg.norm(self._grad(sol))**2 < self.tol


In [None]:
# Let's create a Quadratic Problem.
# Both methods share one stopping tolerance and iteration budget so the
# comparison is like-for-like; change them here rather than per method.
# NOTE(review): MyGradientDescent compares the *squared* gradient norm with
# TOL; confirm the built-in GD uses the same criterion before comparing "nit".
TOL = 1e-3
MAXITER = 300

n = 5
x_init = jnp.array([2.0, 1.0, 3.0, 0.0, 0.0])
# The hand-written starting point must match the problem dimension.
assert x_init.shape == (n,), f"x_init has shape {x_init.shape}, expected ({n},)"
problem = QuadraticProblem(
    n=n,
    mineig=1,
    maxeig=10
)

# Specify your own solver
my_gd_solver = MyGradientDescent(
    x_init=x_init,
    stepsize=1e-2,
    problem=problem,
    tol=TOL,
    maxiter=MAXITER,
    label='MyGD'
)


benchmark = Benchmark(
    runs=2,
    problem=problem,
    methods=[
        {
            'MY_GRADIENT_DESCENT': my_gd_solver
        },
        {
            'GRADIENT_DESCENT_const_step': {
                'x_init': x_init,
                'tol': TOL,
                'maxiter': MAXITER,
                'stepsize': 1e-1,
                'acceleration': False,
                'label': 'GD_const'
            }
        }
    ],
    metrics=[
        "nit",
        "x",
        "f",
        "grad"
    ],
)

In [None]:
# Run every configured method on `problem` (repeated `runs` times, presumably);
# the returned result object is plotted in the next cell.
result = benchmark.run()

In [None]:
# Plot the selected metrics for all methods; the derived names (x_norm, *_gap,
# grad_norm) presumably come from the "x", "f" and "grad" metrics collected above.
result.plot(metrics=['f', 'x_norm', 'f_gap', 'x_gap', 'grad_norm'])