# Benchmarking `pyensmallen` for m-estimation

LBFGS works well for most smooth convex functions, notably a convex loss such as a likelihood. I generally find that optimization convergence is so fast that bootstrapping the entire procedure may be feasible.

Benchmarks against several popular optimization libraries.

In [1]:
import numpy as np
import pandas as pd
import pyensmallen
import scipy.optimize
import cvxpy as cp
from scipy.special import expit
import time

import statsmodels.api as sm
import nlopt

In [2]:
import jax
import jax.numpy as jnp
import optax
import functools
jax.config.update("jax_enable_x64", True)

## Data Generation

In [3]:
np.random.seed(42)
n, k = 1_000_000, 20

# Linear Regression Data
X_linear = np.random.randn(n, k)
print(true_params_linear := np.random.rand(k))
y_linear = X_linear @ true_params_linear

[0.51639859 0.94598022 0.23380001 0.55162275 0.97811966 0.24254699
 0.64702478 0.70271041 0.26476461 0.77362184 0.7817448  0.36874977
 0.72697004 0.06518613 0.72705723 0.38967364 0.03826155 0.39386005
 0.0438693  0.72142769]


## Linear Regression

### pyensmallen

In [4]:
def linear_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    residuals = X @ params - y.reshape(-1, 1)
    objective = np.sum(residuals**2)
    grad = 2 * X.T @ residuals
    gradient[:] = grad.flatten()
    return objective

linear_start = np.random.rand(k)

In [5]:
%%time
optimizer = pyensmallen.L_BFGS()
result_linear_ens = optimizer.optimize(
    lambda params, gradient: linear_objective(params, gradient, X_linear, y_linear),
    linear_start,
)


CPU times: user 6.3 s, sys: 191 ms, total: 6.49 s
Wall time: 427 ms


### nlopt

In [6]:
%%time
opt = nlopt.opt(nlopt.LD_LBFGS, k)
opt.set_min_objective(lambda params, gradient: linear_objective(params, gradient, X_linear, y_linear))
result_linear_nlopt = opt.optimize(linear_start)


CPU times: user 10.8 s, sys: 336 ms, total: 11.1 s
Wall time: 713 ms


### scipy

In [7]:
%%time
result_linear_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum((X_linear @ b - y_linear) ** 2),
    x0=linear_start,
    jac=lambda b: 2 * X_linear.T @ (X_linear @ b - y_linear),
).x


CPU times: user 1min 23s, sys: 2.12 s, total: 1min 25s
Wall time: 5.5 s


### cvxpy

In [8]:
%%time
b_linear = cp.Variable(k)
cost_linear = cp.norm(X_linear @ b_linear - y_linear, p=2) ** 2 / n
prob_linear = cp.Problem(cp.Minimize(cost_linear))
prob_linear.solve(solver=cp.SCS)


CPU times: user 11.4 s, sys: 3.47 s, total: 14.8 s
Wall time: 13.2 s


np.float64(3.813432269764509e-16)

### jax + optax

In [9]:
%%time
X_jnp, y_jnp = jnp.array(X_linear), jnp.array(y_linear)

def compute_loss(beta):
    y_pred = jnp.dot(X_jnp, beta)
    loss = jnp.mean((y_pred - y_jnp) ** 2)
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(compute_loss)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=compute_loss
    )
    params = optax.apply_updates(params, updates)


CPU times: user 5.98 s, sys: 51.4 ms, total: 6.04 s
Wall time: 4.53 s


### closed form

In [10]:
%%time
np_lstsq_sol = np.linalg.lstsq(X_linear, y_linear, rcond=None)[0]

CPU times: user 1.11 s, sys: 43.6 ms, total: 1.16 s
Wall time: 379 ms


In [11]:
%%time
sm_lstsq_sol = sm.OLS(y_linear, X_linear).fit().params

CPU times: user 7.76 s, sys: 357 ms, total: 8.11 s
Wall time: 1.64 s


## Comparison

In [12]:
lin_df = pd.DataFrame({
    "true": true_params_linear,
    "ensmallen": result_linear_ens,
    "nlopt": result_linear_nlopt,
    "scipy": result_linear_scipy,
    "cvxpy": b_linear.value,
    "jax": params,
    "np_lstsq": np_lstsq_sol,
    "sm_lstsq": sm_lstsq_sol,
})
lin_df

Unnamed: 0,true,ensmallen,nlopt,scipy,cvxpy,jax,np_lstsq,sm_lstsq
0,0.516399,0.516399,0.516399,0.516399,0.516399,0.516398,0.516399,0.516399
1,0.94598,0.94598,0.94598,0.94598,0.94598,0.945978,0.94598,0.94598
2,0.2338,0.2338,0.2338,0.2338,0.2338,0.233798,0.2338,0.2338
3,0.551623,0.551623,0.551623,0.551623,0.551623,0.551626,0.551623,0.551623
4,0.97812,0.97812,0.97812,0.97812,0.97812,0.978113,0.97812,0.97812
5,0.242547,0.242547,0.242547,0.242547,0.242547,0.242545,0.242547,0.242547
6,0.647025,0.647025,0.647025,0.647025,0.647025,0.647031,0.647025,0.647025
7,0.70271,0.70271,0.70271,0.70271,0.70271,0.702712,0.70271,0.70271
8,0.264765,0.264765,0.264765,0.264765,0.264765,0.264764,0.264765,0.264765
9,0.773622,0.773622,0.773622,0.773622,0.773622,0.773627,0.773622,0.773622


## Logistic Regression

In [13]:
# Logistic Regression Data
n, k = 10_000, 20
X_logistic = np.random.randn(n, k)
print(true_params_logistic := np.random.rand(k))
p = expit(X_logistic @ true_params_logistic)
y_logistic = np.random.binomial(1, p)

[0.89732722 0.75831337 0.56717163 0.34694649 0.80717784 0.58244036
 0.87299492 0.41850594 0.1207063  0.98533514 0.94064507 0.7165698
 0.34148517 0.21317874 0.24622957 0.77088703 0.99525454 0.92697675
 0.85896413 0.96032642]


### pyensmallen

In [14]:
def logistic_objective(params, gradient, X, y):
    z = X @ params
    h = expit(z)
    objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))
    if np.isnan(objective):
        objective = np.inf
    grad = X.T @ (h - y)
    gradient[:] = grad
    return objective

logistic_start = np.random.rand(k)

In [15]:
%%time
X_logistic2 = np.ascontiguousarray(
    X_logistic
)  # Ensure C-contiguous array for better performance
y_logistic2 = y_logistic.ravel()

optimizer = pyensmallen.L_BFGS()
result_logistic_ens = optimizer.optimize(
    lambda params, gradient: logistic_objective(
        params, gradient, X_logistic2, y_logistic2
    ),
    logistic_start,
)


CPU times: user 60.4 ms, sys: 28 μs, total: 60.5 ms
Wall time: 11.6 ms


### nlopt

In [16]:
%%time
opt = nlopt.opt(nlopt.LD_LBFGS, k)
opt.set_min_objective(lambda params, gradient: logistic_objective(
        params, gradient, X_logistic2, y_logistic2
    ))
result_logistic_nlopt = opt.optimize(logistic_start)


CPU times: user 80.9 ms, sys: 52 μs, total: 80.9 ms
Wall time: 16.1 ms


  objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))
  objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))
  objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))


### scipy

In [17]:
%%time
result_logistic_scipy = scipy.optimize.minimize(
    fun=lambda b: -np.sum(
        y_logistic * np.log(expit(X_logistic @ b))
        + (1 - y_logistic) * np.log(1 - expit(X_logistic @ b))
    ),
    x0=logistic_start,
    jac=lambda b: X_logistic.T @ (expit(X_logistic @ b) - y_logistic),
).x


CPU times: user 52 ms, sys: 3 μs, total: 52 ms
Wall time: 51.5 ms




### cvxpy

In [18]:
%%time
b_logistic = cp.Variable(k)
log_likelihood = cp.sum(
    cp.multiply(y_logistic, X_logistic @ b_logistic)
    - cp.logistic(X_logistic @ b_logistic)
)
prob_logistic = cp.Problem(cp.Maximize(log_likelihood))
prob_logistic.solve(solver=cp.SCS)


CPU times: user 1.06 s, sys: 35.8 ms, total: 1.09 s
Wall time: 1.09 s


np.float64(-3404.919683904648)

### statsmodels

does IRLS

In [19]:
%%time
sm_logit_res = sm.Logit(y_logistic, X_logistic).fit().params

Optimization terminated successfully.
         Current function value: 0.340492
         Iterations 7
CPU times: user 128 ms, sys: 0 ns, total: 128 ms
Wall time: 20.4 ms


In [20]:
%%time
X_jnp, y_jnp = jnp.array(X_logistic), jnp.array(y_logistic)

def logistic_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    h = jax.scipy.special.expit(z)
    loss = -jnp.sum(y_jnp * jnp.log(h) + (1 - y_jnp) * jnp.log1p(-h))
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(logistic_likelihood)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=logistic_likelihood
    )
    params = optax.apply_updates(params, updates)

CPU times: user 4.94 s, sys: 65.7 ms, total: 5 s
Wall time: 4.05 s


## comparison

In [21]:
logit_df = pd.DataFrame(
    {
        "true": true_params_logistic,
        "ensmallen": result_logistic_ens,
        "nlopt": result_logistic_nlopt,
        "scipy": result_logistic_scipy,
        "cvxpy": b_logistic.value,
        "jax": params,
        "sm_logit": sm_logit_res,
    }
)
logit_df

Unnamed: 0,true,ensmallen,nlopt,scipy,cvxpy,jax,sm_logit
0,0.897327,0.917879,0.917879,0.917879,0.917868,0.879663,0.917879
1,0.758313,0.780738,0.780738,0.780738,0.780729,0.790432,0.780738
2,0.567172,0.60736,0.60736,0.60736,0.607352,0.571163,0.60736
3,0.346946,0.363994,0.363994,0.363994,0.36399,0.352212,0.363994
4,0.807178,0.721937,0.721937,0.721937,0.721928,0.721815,0.721937
5,0.58244,0.593548,0.593548,0.593548,0.593541,0.540399,0.593548
6,0.872995,0.859396,0.859396,0.859396,0.859386,0.730609,0.859396
7,0.418506,0.362243,0.362243,0.362243,0.362239,0.36119,0.362243
8,0.120706,0.134975,0.134975,0.134975,0.134973,0.011258,0.134975
9,0.985335,0.993394,0.993394,0.993394,0.993382,0.84039,0.993394


## Poisson Regression

In [22]:
n, k = 100_000, 10
# Poisson Regression Data
X_poisson = np.random.randn(n, k)
print(true_params_poisson := np.random.rand(k))
lambda_ = np.exp(X_poisson @ true_params_poisson)
y_poisson = np.random.poisson(lambda_)

[0.46622737 0.27893256 0.1526658  0.04174791 0.46506248 0.03016092
 0.94782085 0.33601329 0.11498335 0.72554644]


## pyensmallen

In [23]:
def poisson_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    y = y.reshape(-1, 1)
    Xbeta = X @ params
    lambda_ = np.exp(Xbeta)
    objective = np.sum(lambda_ - np.multiply(y, np.log(lambda_)))
    # Compute the gradient
    grad = X.T @ (lambda_ - y)
    gradient[:] = grad.ravel()
    return objective

poisson_start = np.random.rand(k)

In [24]:
%%time
optimizer = pyensmallen.L_BFGS()
result_poisson_ens = optimizer.optimize(
    lambda params, gradient: poisson_objective(params, gradient, X_poisson, y_poisson),
    poisson_start,
)


CPU times: user 828 ms, sys: 0 ns, total: 828 ms
Wall time: 58.1 ms


### nlopt
fails

In [31]:
# %%time
# opt = nlopt.opt(nlopt.LD_LBFGS, k)
# opt.set_min_objective(lambda params, gradient: poisson_objective(params, gradient, X_poisson, y_poisson))
# opt.set_maxeval(100000)
# result_poisson_nlopt = opt.optimize(poisson_start)
# print(result_poisson_nlopt)

### scipy

In [26]:
%%time
result_poisson_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum(np.exp(X_poisson @ b) - y_poisson * (X_poisson @ b)),
    x0=poisson_start,
    jac=lambda b: X_poisson.T @ (np.exp(X_poisson @ b) - y_poisson),
).x


CPU times: user 5.08 s, sys: 0 ns, total: 5.08 s
Wall time: 355 ms


### cvxpy

fails

In [32]:
# %%capture
# b_poisson = cp.Variable(k)
# z = X_poisson @ b_poisson
# cost_poisson = cp.sum(cp.exp(z) - cp.multiply(y_poisson, z)) / n
# prob_poisson = cp.Problem(cp.Minimize(cost_poisson))
# prob_poisson.solve(solver=cp.SCS)


Runs out of memory.

### statsmodels

In [28]:
%%time
sm_poisson_res = sm.Poisson(y_poisson, X_poisson).fit().params

Optimization terminated successfully.
         Current function value: 1.361967
         Iterations 14
CPU times: user 3.83 s, sys: 0 ns, total: 3.83 s
Wall time: 274 ms


## jax
switch to adam since lbfgs performs poorly (!TODO why?)

In [29]:
%%time
X_jnp, y_jnp = jnp.array(X_poisson), jnp.array(y_poisson)

def poisson_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    lambda_ = jnp.exp(z)
    loss = jnp.sum(lambda_ - y_jnp * z)
    return loss


solver = optax.adam(1e-2)
adam_params = jnp.array(poisson_start)
opt_state = solver.init(adam_params)

@jax.jit
def update_step(params, opt_state):
    loss, grads = jax.value_and_grad(poisson_likelihood)(params)
    updates, opt_state = solver.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


for i in range(1000):
    adam_params, opt_state, loss = update_step(adam_params, opt_state)

CPU times: user 3.58 s, sys: 61.8 ms, total: 3.64 s
Wall time: 1.52 s


In [30]:
poi_df = pd.DataFrame(
  {
      "true": true_params_poisson,
      "ensmallen": result_poisson_ens,
      "scipy": result_poisson_scipy,
      "sm_poisson": sm_poisson_res,
      "jax_adam": adam_params,
  }
)
poi_df

Unnamed: 0,true,ensmallen,scipy,sm_poisson,jax_adam
0,0.466227,0.467852,0.467852,0.467852,0.467852
1,0.278933,0.275483,0.275483,0.275483,0.275483
2,0.152666,0.151015,0.151015,0.151015,0.151015
3,0.041748,0.039774,0.039774,0.039774,0.039774
4,0.465062,0.466311,0.466311,0.466311,0.466311
5,0.030161,0.032722,0.032722,0.032722,0.032722
6,0.947821,0.948469,0.948469,0.948469,0.948469
7,0.336013,0.336803,0.336803,0.336803,0.336803
8,0.114983,0.120188,0.120188,0.120188,0.120188
9,0.725546,0.728043,0.728043,0.728043,0.728043
