# Benchmarking `pyensmallen` for m-estimation

LBFGS works well for most smooth convex functions, notably a convex loss such as a likelihood. I generally find that optimization convergence is so fast that bootstrapping the entire procedure may be feasible.

Benchmarks against several popular optimization libraries. `pyensmallen` can serve as a fast drop-in replacement for in-memory optimisation problems where you would otherwise have used scipy or cvxpy.

In [1]:
import numpy as np
import pandas as pd
import pyensmallen
import scipy.optimize
import cvxpy as cp
from scipy.special import expit
import time

import statsmodels.api as sm

In [2]:
import jax
import jax.numpy as jnp
import optax
import functools
jax.config.update("jax_enable_x64", True)

## Data Generation

In [3]:
np.random.seed(42)
n, k = 1_000_000, 20

# Linear Regression Data
X_linear = np.random.randn(n, k)
print(true_params_linear := np.random.rand(k))
y_linear = X_linear @ true_params_linear

[0.51639859 0.94598022 0.23380001 0.55162275 0.97811966 0.24254699
 0.64702478 0.70271041 0.26476461 0.77362184 0.7817448  0.36874977
 0.72697004 0.06518613 0.72705723 0.38967364 0.03826155 0.39386005
 0.0438693  0.72142769]


## Linear Regression

### pyensmallen

In [4]:
def linear_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    residuals = X @ params - y.reshape(-1, 1)
    objective = np.sum(residuals**2)
    grad = 2 * X.T @ residuals
    gradient[:] = grad.flatten()
    return objective

linear_start = np.random.rand(k)

In [5]:
start_time = time.time()
optimizer = pyensmallen.L_BFGS()
result_linear_ens = optimizer.optimize(
    lambda params, gradient: linear_objective(params, gradient, X_linear, y_linear),
    linear_start,
)
end_time = time.time()
ensmallen_linear_time = end_time - start_time
print(f"pyensmallen linear regression time: {ensmallen_linear_time:.6f} seconds")

pyensmallen linear regression time: 0.469582 seconds


### scipy

In [6]:
start_time = time.time()
result_linear_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum((X_linear @ b - y_linear) ** 2),
    x0=linear_start,
    jac=lambda b: 2 * X_linear.T @ (X_linear @ b - y_linear),
).x
end_time = time.time()
scipy_linear_time = end_time - start_time
print(f"scipy linear regression time: {scipy_linear_time:.6f} seconds")

scipy linear regression time: 7.389548 seconds


### cvxpy

In [7]:
start_time = time.time()
b_linear = cp.Variable(k)
cost_linear = cp.norm(X_linear @ b_linear - y_linear, p=2) ** 2 / n
prob_linear = cp.Problem(cp.Minimize(cost_linear))
prob_linear.solve(solver=cp.SCS)
end_time = time.time()
cvxpy_linear_time = end_time - start_time
print(f"cvxpy linear regression time: {cvxpy_linear_time:.6f} seconds")


cvxpy linear regression time: 20.309397 seconds


### jax + optax

In [8]:
start_time = time.time()
X_jnp, y_jnp = jnp.array(X_linear), jnp.array(y_linear)

def compute_loss(beta):
    y_pred = jnp.dot(X_jnp, beta)
    loss = jnp.mean((y_pred - y_jnp) ** 2)
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(compute_loss)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=compute_loss
    )
    params = optax.apply_updates(params, updates)
end_time = time.time()
jax_linear_time = end_time - start_time
print(f"jax linear regression time: {jax_linear_time:.6f} seconds")


jax linear regression time: 3.526115 seconds


### closed form

In [9]:
start_time = time.time()
np_lstsq_sol = np.linalg.lstsq(X_linear, y_linear, rcond=None)[0]
np_lstsq_time = end_time - start_time
print(f"numpy lstsq linear regression time: {np_lstsq_time:.6f} seconds")

numpy lstsq linear regression time: -1.514380 seconds


In [10]:
start_time = time.time()
sm_lstsq_sol = sm.OLS(y_linear, X_linear).fit().params
end_time = time.time()
sm_lstsq_time = end_time - start_time
print(f"statsmodels lstsq linear regression time: {sm_lstsq_time:.6f} seconds")


statsmodels lstsq linear regression time: 2.020317 seconds


## Comparison

In [11]:
lin_results = {
    "true": true_params_linear,
    "ensmallen": result_linear_ens,
    "scipy": result_linear_scipy,
    "cvxpy": b_linear.value,
    "jax": params,
    "np_lstsq": np_lstsq_sol,
    "sm_lstsq": sm_lstsq_sol,
}
lin_times = {
    "ensmallen": ensmallen_linear_time,
    "scipy": scipy_linear_time,
    "cvxpy": cvxpy_linear_time,
    "jax": jax_linear_time,
    "np_lstsq": np_lstsq_time,
    "sm_lstsq": sm_lstsq_time,
}
lin_df = pd.DataFrame(lin_results).T
lin_df["time"] = lin_df.index.map(lin_times)
lin_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,time
true,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,
ensmallen,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,0.469582
scipy,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,7.389548
cvxpy,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,20.309397
jax,0.516398,0.945978,0.233798,0.551626,0.978113,0.242545,0.647031,0.702712,0.264764,0.773627,...,0.368748,0.726966,0.06518,0.727062,0.389668,0.038263,0.393856,0.043868,0.721436,3.526115
np_lstsq,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,-1.51438
sm_lstsq,0.516399,0.94598,0.2338,0.551623,0.97812,0.242547,0.647025,0.70271,0.264765,0.773622,...,0.36875,0.72697,0.065186,0.727057,0.389674,0.038262,0.39386,0.043869,0.721428,2.020317


Numpy is obviously fastest here since we're doing a closed form solution. `pyensmallen` is fastest among MSE minimisers (and also beats statsmodels, which adds a lot of overhead over numpy despite using closed form).

## Logistic Regression

In [12]:
# Logistic Regression Data
n, k = 100_000, 20
X_logistic = np.random.randn(n, k)
print(true_params_logistic := np.random.rand(k))
p = expit(X_logistic @ true_params_logistic)
y_logistic = np.random.binomial(1, p)

[0.37971646 0.76615664 0.97499895 0.41514807 0.72909192 0.49428436
 0.14184926 0.21549712 0.66351905 0.62357696 0.16687515 0.32194925
 0.50776434 0.61101253 0.80298513 0.61154467 0.05334338 0.17495924
 0.08381726 0.10534799]


### pyensmallen

In [13]:
def logistic_objective(params, gradient, X, y):
    z = X @ params
    h = expit(z)
    objective = -np.sum(y * np.log(h) + (1 - y) * np.log1p(-h))
    if np.isnan(objective):
        objective = np.inf
    grad = X.T @ (h - y)
    gradient[:] = grad
    return objective

logistic_start = np.random.rand(k)

In [14]:
start_time = time.time()
X_logistic2 = np.ascontiguousarray(
    X_logistic
)  # Ensure C-contiguous array for better performance
y_logistic2 = y_logistic.ravel()

optimizer = pyensmallen.L_BFGS()
result_logistic_ens = optimizer.optimize(
    lambda params, gradient: logistic_objective(
        params, gradient, X_logistic2, y_logistic2
    ),
    logistic_start,
)
end_time = time.time()
ensmallen_logistic_time = end_time - start_time
print(f"pyensmallen logistic regression time: {ensmallen_logistic_time:.6f} seconds")

pyensmallen logistic regression time: 0.138552 seconds


### scipy

In [15]:
start_time = time.time()
result_logistic_scipy = scipy.optimize.minimize(
    fun=lambda b: -np.sum(
        y_logistic * np.log(expit(X_logistic @ b))
        + (1 - y_logistic) * np.log(1 - expit(X_logistic @ b))
    ),
    x0=logistic_start,
    jac=lambda b: X_logistic.T @ (expit(X_logistic @ b) - y_logistic),
).x
end_time = time.time()
scipy_logistic_time = end_time - start_time
print(f"scipy logistic regression time: {scipy_logistic_time:.6f} seconds")

  + (1 - y_logistic) * np.log(1 - expit(X_logistic @ b))
  + (1 - y_logistic) * np.log(1 - expit(X_logistic @ b))


scipy logistic regression time: 0.783805 seconds


### cvxpy

In [16]:
start_time = time.time()
b_logistic = cp.Variable(k)
log_likelihood = cp.sum(
    cp.multiply(y_logistic, X_logistic @ b_logistic)
    - cp.logistic(X_logistic @ b_logistic)
)
prob_logistic = cp.Problem(cp.Maximize(log_likelihood))
prob_logistic.solve(solver=cp.SCS)
end_time = time.time()
cvxpy_logistic_time = end_time - start_time
print(f"cvxpy logistic regression time: {cvxpy_logistic_time:.6f} seconds")

cvxpy logistic regression time: 14.463672 seconds


### statsmodels

does IRLS

In [17]:
start_time = time.time()
sm_logit_res = sm.Logit(y_logistic, X_logistic).fit().params
end_time = time.time()
sm_logistic_time = end_time - start_time
print(f"statsmodels logistic regression time: {sm_logistic_time:.6f} seconds")

Optimization terminated successfully.
         Current function value: 0.425222
         Iterations 7
statsmodels logistic regression time: 0.373359 seconds


In [18]:
start_time = time.time()
X_jnp, y_jnp = jnp.array(X_logistic), jnp.array(y_logistic)

def logistic_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    h = jax.scipy.special.expit(z)
    loss = -jnp.sum(y_jnp * jnp.log(h) + (1 - y_jnp) * jnp.log1p(-h))
    return loss

params = jnp.array(linear_start)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(logistic_likelihood)

# Optimization loop
for i in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=logistic_likelihood
    )
    params = optax.apply_updates(params, updates)

end_time = time.time()
jax_logistic_time = end_time - start_time
print(f"jax logistic regression time: {jax_logistic_time:.6f} seconds")

jax logistic regression time: 3.161464 seconds


## comparison

In [19]:
logit_df = pd.DataFrame(
    {
        "true": true_params_logistic,
        "ensmallen": result_logistic_ens,
        "scipy": result_logistic_scipy,
        "cvxpy": b_logistic.value,
        "jax": params,
        "sm_logit": sm_logit_res,
    }
)

logit_times = {
    "ensmallen": ensmallen_logistic_time,
    "scipy": scipy_logistic_time,
    "cvxpy": cvxpy_logistic_time,
    "jax": jax_logistic_time,
    "sm_logit": sm_logistic_time,
}

logit_df = pd.DataFrame(logit_df).T
logit_df["time"] = logit_df.index.map(logit_times)
logit_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,time
true,0.379716,0.766157,0.974999,0.415148,0.729092,0.494284,0.141849,0.215497,0.663519,0.623577,...,0.321949,0.507764,0.611013,0.802985,0.611545,0.053343,0.174959,0.083817,0.105348,
ensmallen,0.392438,0.769003,0.971823,0.41394,0.736091,0.502503,0.156761,0.221206,0.677057,0.641259,...,0.311772,0.509763,0.595544,0.8128,0.626395,0.047291,0.174555,0.08382,0.110272,0.138552
scipy,0.392438,0.769003,0.971823,0.41394,0.736091,0.502503,0.156761,0.221206,0.677057,0.641259,...,0.311772,0.509763,0.595544,0.8128,0.626395,0.047291,0.174555,0.08382,0.110272,0.783805
cvxpy,0.392435,0.768999,0.971818,0.413938,0.736086,0.5025,0.15676,0.221204,0.677053,0.641255,...,0.31177,0.50976,0.595541,0.812795,0.626391,0.04729,0.174554,0.08382,0.110271,14.463672
jax,0.386978,0.79077,0.98746,0.418043,0.755707,0.494496,0.113446,0.224108,0.654783,0.602921,...,0.311,0.500431,0.612805,0.829732,0.640011,0.030524,0.165428,0.029116,0.108533,3.161464
sm_logit,0.392438,0.769003,0.971823,0.41394,0.736091,0.502503,0.156761,0.221206,0.677057,0.641259,...,0.311772,0.509763,0.595544,0.8128,0.626395,0.047291,0.174555,0.08382,0.110272,0.373359


Ensmallen is fastest among all libraries.

## Poisson Regression

In [20]:
n, k = 100_000, 10
# Poisson Regression Data
X_poisson = np.random.randn(n, k)
print(true_params_poisson := np.random.rand(k))
lambda_ = np.exp(X_poisson @ true_params_poisson)
y_poisson = np.random.poisson(lambda_)

[0.23958574 0.94348091 0.55742479 0.3897755  0.91983041 0.73762105
 0.92259406 0.70978266 0.77998603 0.8961713 ]


## pyensmallen

In [21]:
def poisson_objective(params, gradient, X, y):
    params = params.reshape(-1, 1)
    y = y.reshape(-1, 1)
    Xbeta = X @ params
    lambda_ = np.exp(Xbeta)
    objective = np.sum(lambda_ - np.multiply(y, np.log(lambda_)))
    # Compute the gradient
    grad = X.T @ (lambda_ - y)
    gradient[:] = grad.ravel()
    return objective

poisson_start = np.random.rand(k)

In [22]:
start_time = time.time()
optimizer = pyensmallen.L_BFGS()
result_poisson_ens = optimizer.optimize(
    lambda params, gradient: poisson_objective(params, gradient, X_poisson, y_poisson),
    poisson_start,
)
end_time = time.time()
ensmallen_poisson_time = end_time - start_time
print(f"pyensmallen poisson regression time: {ensmallen_poisson_time:.6f} seconds")

pyensmallen poisson regression time: 0.166607 seconds


### scipy

In [23]:
start_time = time.time()
result_poisson_scipy = scipy.optimize.minimize(
    fun=lambda b: np.sum(np.exp(X_poisson @ b) - y_poisson * (X_poisson @ b)),
    x0=poisson_start,
    jac=lambda b: X_poisson.T @ (np.exp(X_poisson @ b) - y_poisson),
).x
end_time = time.time()
scipy_poisson_time = end_time - start_time
print(f"scipy poisson regression time: {scipy_poisson_time:.6f} seconds")

scipy poisson regression time: 0.315395 seconds


### cvxpy

fails

In [25]:
# start_time = time.time()
# b_poisson = cp.Variable(k)
# z = X_poisson @ b_poisson
# cost_poisson = cp.sum(cp.exp(z) - cp.multiply(y_poisson, z)) / n
# prob_poisson = cp.Problem(cp.Minimize(cost_poisson))
# prob_poisson.solve(solver=cp.SCS)

# end_time = time.time()
# cvxpy_poisson_time = end_time - start_time
# print(f"cvxpy poisson regression time: {cvxpy_poisson_time:.6f} seconds")

Interrupted after 5 mins.

### statsmodels

In [24]:
start_time = time.time()
sm_poisson_res = sm.Poisson(y_poisson, X_poisson).fit().params
end_time = time.time()
sm_poisson_time = end_time - start_time
print(f"sm poisson regression time: {sm_poisson_time:.6f} seconds")

         Current function value: 3947169750512799336144193892850757803797908353139605504.000000
         Iterations: 35
sm poisson regression time: 0.563872 seconds




## jax
switch to adam since lbfgs performs poorly (!TODO why?)

In [25]:
%%time
start_time = time.time()
X_jnp, y_jnp = jnp.array(X_poisson), jnp.array(y_poisson)

def poisson_likelihood(beta):
    z = jnp.dot(X_jnp, beta)
    lambda_ = jnp.exp(z)
    loss = jnp.sum(lambda_ - y_jnp * z)
    return loss


solver = optax.adam(1e-2)
adam_params = jnp.array(poisson_start)
opt_state = solver.init(adam_params)

@jax.jit
def update_step(params, opt_state):
    loss, grads = jax.value_and_grad(poisson_likelihood)(params)
    updates, opt_state = solver.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


for i in range(1000):
    adam_params, opt_state, loss = update_step(adam_params, opt_state)

end_time = time.time()
jax_poisson_time = end_time - start_time
print(f"jax poisson regression time: {jax_poisson_time:.6f} seconds")

jax poisson regression time: 2.062514 seconds
CPU times: user 2.82 s, sys: 116 ms, total: 2.94 s
Wall time: 2.06 s


In [26]:
poi_df = pd.DataFrame(
    {
        "true": true_params_poisson,
        "ensmallen": result_poisson_ens,
        "scipy": result_poisson_scipy,
        "sm_poisson": sm_poisson_res,
        # "cvxpy": b_poisson.value,
        "jax": adam_params,
    }
)


poi_times = {
    "ensmallen": ensmallen_poisson_time,
    "scipy": scipy_poisson_time,
    "jax": jax_poisson_time,
    # "cvxpy": cvxpy_poisson_time,
    "sm_poisson": sm_poisson_time,
}

poi_df = pd.DataFrame(poi_df).T
poi_df["time"] = poi_df.index.map(poi_times)
poi_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,time
true,0.239586,0.943481,0.557425,0.389775,0.91983,0.737621,0.922594,0.709783,0.779986,0.896171,
ensmallen,0.239864,0.944432,0.558993,0.390959,0.919913,0.737666,0.922831,0.708114,0.778613,0.895464,0.166607
scipy,0.239864,0.944432,0.558993,0.390959,0.919913,0.737666,0.922831,0.708114,0.778613,0.895464,0.315395
sm_poisson,6.231685,9.539198,6.07923,5.737063,11.695746,9.101099,11.501975,8.929547,11.217799,9.636084,0.563872
jax,0.239864,0.944432,0.558993,0.390959,0.919913,0.737666,0.922831,0.708114,0.778613,0.895464,2.062514


Ensmallen is fastest again. CVXPY fails to converge after 5 minutes, as does statsmodels.