# Benchmarking `pyensmallen` for m-estimation

LBFGS works well for most smooth convex functions, notably a convex loss such as a likelihood. I generally find that optimization convergence is so fast that bootstrapping the entire procedure may be feasible.

Benchmarks against several popular optimization libraries.

In [2]:
import numpy as np
import pandas as pd
import pyensmallen
import scipy.optimize
import cvxpy as cp
from scipy.special import expit
import time

import statsmodels.api as sm

In [3]:
import jax
import jax.numpy as jnp
import optax
import functools
jax.config.update("jax_enable_x64", True)

## Data Generation

In [4]:
np.random.seed(42)
n, k = 1_000_000, 20

# Linear Regression Data
X_linear = np.random.randn(n, k)
print(true_params_linear := np.random.rand(k))
y_linear = X_linear @ true_params_linear

[0.51639859 0.94598022 0.23380001 0.55162275 0.97811966 0.24254699
 0.64702478 0.70271041 0.26476461 0.77362184 0.7817448  0.36874977
 0.72697004 0.06518613 0.72705723 0.38967364 0.03826155 0.39386005
 0.0438693  0.72142769]


## Linear Regression

### pyensmallen

In [7]:
def linear_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    residuals = X @ params - y.reshape(-1, 1)
    objective = np.sum(residuals**2)
    grad = 2 * X.T @ residuals
    gradient[:] = grad.flatten()
    return objective

linear_start = np.random.rand(k)

In [8]:
%%time
optimizer = pyensmallen.L_BFGS()
result_linear_ens = optimizer.optimize(
    lambda params, gradient: linear_objective(params, gradient, X_linear, y_linear),
    linear_start,
)


CPU times: user 6.25 s, sys: 182 ms, total: 6.43 s
Wall time: 451 ms


### scipy

In [9]:
%%time
result_linear_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum((X_linear @ b - y_linear) ** 2),
    x0=linear_start,
    jac=lambda b: 2 * X_linear.T @ (X_linear @ b - y_linear),
).x


CPU times: user 1min 34s, sys: 3.01 s, total: 1min 37s
Wall time: 6.57 s


### cvxpy

In [10]:
%%time
b_linear = cp.Variable(k)
cost_linear = cp.norm(X_linear @ b_linear - y_linear, p=2) ** 2 / n
prob_linear = cp.Problem(cp.Minimize(cost_linear))
prob_linear.solve(solver=cp.SCS)


CPU times: user 11.3 s, sys: 3.88 s, total: 15.2 s
Wall time: 15.1 s


np.float64(3.813432269764509e-16)

### jax + optax

In [11]:
%%time
X_jnp, y_jnp = jnp.array(X_linear), jnp.array(y_linear)

def compute_loss(beta):
    y_pred = jnp.dot(X_jnp, beta)
    loss = jnp.mean((y_pred - y_jnp) ** 2)
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(compute_loss)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=compute_loss
    )
    params = optax.apply_updates(params, updates)


CPU times: user 4.92 s, sys: 127 ms, total: 5.05 s
Wall time: 5.06 s


### closed form

In [12]:
%%time
np_lstsq_sol = np.linalg.lstsq(X_linear, y_linear, rcond=None)[0]

CPU times: user 1.3 s, sys: 48.2 ms, total: 1.35 s
Wall time: 439 ms


In [13]:
%%time
sm_lstsq_sol = sm.OLS(y_linear, X_linear).fit().params

CPU times: user 7.96 s, sys: 370 ms, total: 8.33 s
Wall time: 1.87 s


## Comparison

In [14]:
lin_df = pd.DataFrame({
    "true": true_params_linear,
    "ensmallen": result_linear_ens,
    "scipy": result_linear_scipy,
    "cvxpy": b_linear.value,
    "jax": params,
    "np_lstsq": np_lstsq_sol,
    "sm_lstsq": sm_lstsq_sol,
})
lin_df

Unnamed: 0,true,ensmallen,scipy,cvxpy,jax,np_lstsq,sm_lstsq
0,0.516399,0.516399,0.516399,0.516399,0.516396,0.516399,0.516399
1,0.94598,0.94598,0.94598,0.94598,0.945986,0.94598,0.94598
2,0.2338,0.2338,0.2338,0.2338,0.233798,0.2338,0.2338
3,0.551623,0.551623,0.551623,0.551623,0.551621,0.551623,0.551623
4,0.97812,0.97812,0.97812,0.97812,0.978121,0.97812,0.97812
5,0.242547,0.242547,0.242547,0.242547,0.242548,0.242547,0.242547
6,0.647025,0.647025,0.647025,0.647025,0.647029,0.647025,0.647025
7,0.70271,0.70271,0.70271,0.70271,0.70271,0.70271,0.70271
8,0.264765,0.264765,0.264765,0.264765,0.264757,0.264765,0.264765
9,0.773622,0.773622,0.773622,0.773622,0.773628,0.773622,0.773622


## Logistic Regression

In [15]:
# Logistic Regression Data
n, k = 10_000, 20
X_logistic = np.random.randn(n, k)
print(true_params_logistic := np.random.rand(k))
p = expit(X_logistic @ true_params_logistic)
y_logistic = np.random.binomial(1, p)

[0.24622957 0.77088703 0.99525454 0.92697675 0.85896413 0.96032642
 0.28103909 0.74634994 0.37044719 0.91379069 0.76225183 0.34217879
 0.09916425 0.81927817 0.48192179 0.49853671 0.91310665 0.24428751
 0.38247698 0.75367349]


### pyensmallen

In [16]:
def logistic_objective(params, gradient, X, y):
    z = X @ params
    h = expit(z)
    objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))
    if np.isnan(objective):
        objective = np.inf
    grad = X.T @ (h - y)
    gradient[:] = grad
    return objective

logistic_start = np.random.rand(k)

In [17]:
%%time
X_logistic2 = np.ascontiguousarray(
    X_logistic
)  # Ensure C-contiguous array for better performance
y_logistic2 = y_logistic.ravel()

optimizer = pyensmallen.L_BFGS()
result_logistic_ens = optimizer.optimize(
    lambda params, gradient: logistic_objective(
        params, gradient, X_logistic2, y_logistic2
    ),
    logistic_start,
)


CPU times: user 9.43 ms, sys: 27 μs, total: 9.45 ms
Wall time: 9.02 ms


### scipy

In [18]:
%%time
result_logistic_scipy = scipy.optimize.minimize(
    fun=lambda b: -np.sum(
        y_logistic * np.log(expit(X_logistic @ b))
        + (1 - y_logistic) * np.log(1 - expit(X_logistic @ b))
    ),
    x0=logistic_start,
    jac=lambda b: X_logistic.T @ (expit(X_logistic @ b) - y_logistic),
).x


CPU times: user 38.8 ms, sys: 23 μs, total: 38.8 ms
Wall time: 38 ms




### cvxpy

In [19]:
%%time
b_logistic = cp.Variable(k)
log_likelihood = cp.sum(
    cp.multiply(y_logistic, X_logistic @ b_logistic)
    - cp.logistic(X_logistic @ b_logistic)
)
prob_logistic = cp.Problem(cp.Maximize(log_likelihood))
prob_logistic.solve(solver=cp.SCS)


CPU times: user 1.19 s, sys: 23.2 ms, total: 1.21 s
Wall time: 1.21 s


np.float64(-3596.331511780124)

### statsmodels

does IRLS

In [20]:
%%time
sm_logit_res = sm.Logit(y_logistic, X_logistic).fit().params

Optimization terminated successfully.
         Current function value: 0.359633
         Iterations 7
CPU times: user 247 ms, sys: 1.62 ms, total: 248 ms
Wall time: 38.3 ms


In [21]:
%%time
X_jnp, y_jnp = jnp.array(X_logistic), jnp.array(y_logistic)

def logistic_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    h = jax.scipy.special.expit(z)
    loss = -jnp.sum(y_jnp * jnp.log(h) + (1 - y_jnp) * jnp.log1p(-h))
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(logistic_likelihood)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=logistic_likelihood
    )
    params = optax.apply_updates(params, updates)

CPU times: user 4.34 s, sys: 71.9 ms, total: 4.41 s
Wall time: 4.35 s


## comparison

In [22]:
logit_df = pd.DataFrame(
    {
        "true": true_params_logistic,
        "ensmallen": result_logistic_ens,
        "scipy": result_logistic_scipy,
        "cvxpy": b_logistic.value,
        "jax": params,
        "sm_logit": sm_logit_res,
    }
)
logit_df

Unnamed: 0,true,ensmallen,scipy,cvxpy,jax,sm_logit
0,0.24623,0.266128,0.266128,0.266125,0.214509,0.266128
1,0.770887,0.772562,0.772562,0.772553,0.750707,0.772562
2,0.995255,0.99279,0.99279,0.992778,1.049312,0.99279
3,0.926977,0.950981,0.950981,0.95097,0.99383,0.950981
4,0.858964,0.802078,0.802078,0.802069,0.770684,0.802078
5,0.960326,0.968322,0.968322,0.968311,1.030593,0.968322
6,0.281039,0.259412,0.259412,0.259409,0.197726,0.259412
7,0.74635,0.72167,0.72167,0.721662,0.730997,0.72167
8,0.370447,0.374674,0.374674,0.37467,0.383463,0.374674
9,0.913791,0.909638,0.909638,0.909628,0.929301,0.909638


## Poisson Regression

In [23]:
n, k = 100_000, 10
# Poisson Regression Data
X_poisson = np.random.randn(n, k)
print(true_params_poisson := np.random.rand(k))
lambda_ = np.exp(X_poisson @ true_params_poisson)
y_poisson = np.random.poisson(lambda_)

[0.21058295 0.37114233 0.11747555 0.85610033 0.10420509 0.38705405
 0.85924089 0.35038735 0.99516334 0.38716221]


## pyensmallen

In [24]:
def poisson_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    y = y.reshape(-1, 1)
    Xbeta = X @ params
    lambda_ = np.exp(Xbeta)
    objective = np.sum(lambda_ - np.multiply(y, np.log(lambda_)))
    # Compute the gradient
    grad = X.T @ (lambda_ - y)
    gradient[:] = grad.ravel()
    return objective

poisson_start = np.random.rand(k)

In [25]:
%%time
optimizer = pyensmallen.L_BFGS()
result_poisson_ens = optimizer.optimize(
    lambda params, gradient: poisson_objective(params, gradient, X_poisson, y_poisson),
    poisson_start,
)


CPU times: user 1.49 s, sys: 7.38 ms, total: 1.49 s
Wall time: 111 ms


### scipy

In [26]:
%%time
result_poisson_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum(np.exp(X_poisson @ b) - y_poisson * (X_poisson @ b)),
    x0=poisson_start,
    jac=lambda b: X_poisson.T @ (np.exp(X_poisson @ b) - y_poisson),
).x


CPU times: user 8.57 s, sys: 4.37 ms, total: 8.57 s
Wall time: 586 ms


### cvxpy

fails

In [27]:
# %%capture
# b_poisson = cp.Variable(k)
# z = X_poisson @ b_poisson
# cost_poisson = cp.sum(cp.exp(z) - cp.multiply(y_poisson, z)) / n
# prob_poisson = cp.Problem(cp.Minimize(cost_poisson))
# prob_poisson.solve(solver=cp.SCS)


Runs out of memory.

### statsmodels

In [28]:
%%time
sm_poisson_res = sm.Poisson(y_poisson, X_poisson).fit().params

Optimization terminated successfully.
         Current function value: 1.383029
         Iterations 28
CPU times: user 7.08 s, sys: 13.5 ms, total: 7.09 s
Wall time: 515 ms


## jax
switch to adam since lbfgs performs poorly (!TODO why?)

In [29]:
%%time
X_jnp, y_jnp = jnp.array(X_poisson), jnp.array(y_poisson)

def poisson_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    lambda_ = jnp.exp(z)
    loss = jnp.sum(lambda_ - y_jnp * z)
    return loss


solver = optax.adam(1e-2)
adam_params = jnp.array(poisson_start)
opt_state = solver.init(adam_params)

@jax.jit
def update_step(params, opt_state):
    loss, grads = jax.value_and_grad(poisson_likelihood)(params)
    updates, opt_state = solver.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


for i in range(1000):
    adam_params, opt_state, loss = update_step(adam_params, opt_state)

CPU times: user 2.21 s, sys: 88.6 ms, total: 2.3 s
Wall time: 1.82 s


In [30]:
poi_df = pd.DataFrame(
  {
      "true": true_params_poisson,
      "ensmallen": result_poisson_ens,
      "scipy": result_poisson_scipy,
      "sm_poisson": sm_poisson_res,
      "jax_adam": adam_params,
  }
)
poi_df

Unnamed: 0,true,ensmallen,scipy,sm_poisson,jax_adam
0,0.210583,0.208376,0.208376,0.208376,0.208376
1,0.371142,0.368691,0.368691,0.368691,0.368691
2,0.117476,0.116927,0.116927,0.116927,0.116927
3,0.8561,0.855668,0.855668,0.855668,0.855668
4,0.104205,0.103997,0.103997,0.103997,0.103997
5,0.387054,0.388284,0.388284,0.388284,0.388284
6,0.859241,0.861597,0.861597,0.861597,0.861597
7,0.350387,0.350552,0.350552,0.350552,0.350552
8,0.995163,0.996572,0.996572,0.996572,0.996572
9,0.387162,0.386803,0.386803,0.386803,0.386803
