In [1]:
# Install the exact benchmarx release this notebook was written against (pinned for reproducibility).
%pip install benchmarx==0.0.11 --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/75.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.7/75.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m166.6/166.6 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m53.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.0/190.0 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.8/224.8 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [9]:
from benchmarx import Benchmark
from benchmarx.custom_optimizer import CustomOptimizer, State
from benchmarx._problems.log_regr import LogisticRegression

import jax
import jax.numpy as jnp
import random
import urllib.request

In [3]:
class CSGD(CustomOptimizer):
    """
    Coordinate-sampling SGD (CSGD) for LogLoss.

    Every step computes the full gradient of ``problem.f``, keeps ``batch``
    coordinates chosen uniformly at random without replacement, zeroes the
    rest, and rescales the kept ones by ``d / batch``. Each coordinate is kept
    with probability ``batch / d``, so the rescaled sparse vector is an
    unbiased estimate of the full gradient.

    Args:
        x_init: starting point, forwarded to ``CustomOptimizer``.
        stepsize: constant step size.
        problem: object exposing ``f`` (differentiated with ``jax.grad``)
            and ``d_train`` (the length of the iterate).
        tol: recorded in ``params``; ``stop_criterion`` never stops early.
        maxiter: iteration budget, recorded in ``params``.
        label: label forwarded to ``CustomOptimizer``.
        batch: number of coordinates updated per step. Defaults to 10, the
            value that used to be hard-coded. If it exceeds the dimension,
            every coordinate is updated (a plain full-gradient step).

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label='CSGD', batch=10):
        if batch < 1:
            raise ValueError(f"batch must be >= 1, got {batch}")
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.batch = batch
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Return the initial state: iteration counter 1 and the constant step."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def update(self, sol, state: State) -> tuple[jnp.ndarray, State]:
        """
        Take one CSGD step from ``sol``.

        Returns the new iterate together with ``state``, whose ``iter_num``
        has been incremented in place.
        """
        d = self.problem.d_train
        # random.sample raises if asked for more items than exist, so cap k at d.
        k = min(self.batch, d)

        full_grad = jax.grad(self.problem.f)(sol)

        # Pick k distinct coordinates and scatter their gradient entries into zeros.
        coords = jnp.asarray(random.sample(range(d), k))
        g = jnp.zeros(d).at[coords].set(full_grad[coords])

        # Rescale by d / k so the expected step equals a full-gradient step.
        sol = sol - self.stepsize * d / k * g
        state.iter_num += 1
        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never stop early; run length is presumably governed by ``maxiter`` in the base class."""
        return False


In [4]:
class SGD(CustomOptimizer):
    """
    Minibatch SGD for LogLoss.

    Each step draws ``batch`` distinct training examples uniformly from the
    whole training set and moves against the average of their per-example
    gradients ``problem.grad_log_loss_ind(sol, i)``.

    Args:
        x_init: starting point, forwarded to ``CustomOptimizer``.
        stepsize: constant step size.
        problem: object exposing ``y_train`` (its first axis gives the number
            of training examples), ``d_train`` (the length of the iterate) and
            ``grad_log_loss_ind(sol, i)``.
        tol: recorded in ``params``; ``stop_criterion`` never stops early.
        maxiter: iteration budget, recorded in ``params``.
        label: label forwarded to ``CustomOptimizer``.
        batch: number of examples per step. Defaults to 1, the value that used
            to be hard-coded. It is capped at the number of training examples.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    def __init__(self, x_init, stepsize, problem, tol=0, maxiter=1000, label='SGD', batch=1):
        if batch < 1:
            raise ValueError(f"batch must be >= 1, got {batch}")
        params = {
            'x_init': x_init,
            'tol': tol,
            'maxiter': maxiter,
            'stepsize': stepsize
        }
        self.stepsize = stepsize
        self.problem = problem
        self.maxiter = maxiter
        self.batch = batch
        self.tol = tol
        super().__init__(params=params, x_init=x_init, label=label)

    def init_state(self, x_init, *args, **kwargs) -> State:
        """Return the initial state: iteration counter 1 and the constant step."""
        return State(
            iter_num=1,
            stepsize=self.stepsize
        )

    def update(self, sol, state: State) -> tuple[jnp.ndarray, State]:
        """
        Take one minibatch SGD step from ``sol``.

        Returns the new iterate together with ``state``, whose ``iter_num``
        has been incremented in place.

        Raises:
            ValueError: if the training set is empty.
        """
        # Sample from all n training examples, not only the first n // 10,
        # so the minibatch gradient estimates the loss over the whole training set.
        n = self.problem.y_train.shape[0]
        if n == 0:
            raise ValueError("cannot run SGD on an empty training set")
        d = self.problem.d_train
        k = min(self.batch, n)

        indices = random.sample(range(n), k)
        # Presumably grad_log_loss_ind(sol, i) is the gradient of the i-th
        # example's loss w.r.t. sol, a length-d vector; TODO confirm.
        g = sum((self.problem.grad_log_loss_ind(sol, i) for i in indices), jnp.zeros(d))

        sol = sol - self.stepsize / k * g
        state.iter_num += 1
        return sol, state

    def stop_criterion(self, sol, state: State) -> bool:
        """Never stop early; run length is presumably governed by ``maxiter`` in the base class."""
        return False


In [19]:
# Logistic-regression problem of benchmarx type "breast_cancer"; the l2
# regularizer mentioned in `info` is attached in the next cell.
problem = LogisticRegression(
    problem_type="breast_cancer",
    info="Logistic Regression problem on 'breast cancer' dataset, l2-regularization",
)

In [20]:
# Estimation of the Lipschitz constant of the gradient
L = problem.estimate_L()

# Strength of the l2 penalty relative to L: coefficient = L / L2_REG_DIVISOR.
L2_REG_DIVISOR = 2500


def regularizer(w):
    """
    l2 regularizer ``(L / L2_REG_DIVISOR) * ||w||_2^2`` (reads the global ``L``).

    ``||w||^2`` is computed as ``sum(w ** 2)`` rather than ``norm(w) ** 2``:
    JAX differentiates ``norm`` through a square root, which yields NaN
    gradients at ``w = 0``. For the 1-D weight vector used here the two
    expressions are mathematically identical.
    """
    return L / L2_REG_DIVISOR * jnp.sum(w ** 2)


problem.regularizer = regularizer

In [21]:
# One seed for every source of randomness, so Restart & Run All reproduces the
# results: the JAX key fixes the starting point, and Python's `random` module
# (used by CSGD/SGD to sample coordinates/examples) is seeded as well.
SEED = 110520
random.seed(SEED)
key = jax.random.PRNGKey(SEED)

# Starting point drawn uniformly from [0, 1) in each of the d_train coordinates.
x_init = jax.random.uniform(key, minval=0, maxval=1, shape=(problem.d_train,))

# Iteration budget shared by every method in the benchmark.
nit = 250

In [22]:
# CSGD instance for the benchmark: constant step 34 / L and the shared
# iteration budget. tol=0 is recorded, but CSGD never stops early.
csgd_solver = CSGD(
    problem=problem,
    x_init=x_init,
    stepsize=34 / L,
    maxiter=nit,
    tol=0,
    label="CSGD",
)

In [23]:
# Compare the custom CSGD solver with two built-in gradient-descent variants.
# All methods share the same starting point and iteration budget; the
# objective value "f" is the only tracked metric.
benchmark = Benchmark(
    runs=2,
    problem=problem,
    methods=[
        # Custom optimizer instance built in the previous cell.
        {"CSGD": csgd_solver},
        # Gradient descent with the constant step 20 / L.
        {
            "GRADIENT_DESCENT_const_step": {
                "x_init": x_init,
                "tol": 0,
                "maxiter": nit,
                "stepsize": 20 / L,
                "acceleration": False,
                "label": "GD",
            },
        },
        # Gradient descent with step 20 / (L + k / 20), which shrinks as the
        # iteration counter k grows.
        {
            "GRADIENT_DESCENT_adapt_step": {
                "x_init": x_init,
                "tol": 0,
                "maxiter": nit,
                "stepsize": lambda iter_num: 20 / (L + iter_num / 20),
                "acceleration": False,
                "label": "GD adapt step",
            },
        },
    ],
    metrics=["f"],
)

In [24]:
# Run the benchmark configured above (runs=2; presumably each method is run that many times).
result = benchmark.run()

In [25]:
# Plot the collected results; the benchmark was configured to track the "f" metric.
result.plot()