In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import pyensmallen
import time
import optax

# Seed NumPy's global RNG so the synthetic data and the random starting point
# drawn in later cells are reproducible on a fresh kernel.
np.random.seed(0)
# JAX PRNG key; no cell visible in this notebook consumes it.
key = jax.random.PRNGKey(0)

In [2]:
# Set the parameters
K = 4  # number of classes
D = 10  # number of features
N = 10_000  # number of samples

# Generate true coefficients (K categories, last category is reference with zeros)
# Shape is (D, K): one column of coefficients per class.
true_coeffs = np.random.randn(D, K)
true_coeffs[:, -1] = 0  # Set last category coefficients to zero

# Generate features, shape (N, D)
X = np.random.randn(N, D)

# Generate probabilities and labels
# logits has shape (N, K); probs is the row-wise softmax of logits.
logits = X @ true_coeffs
# NOTE: plain exp/sum softmax without subtracting the row maximum; this is
# only safe while the logits stay small enough not to overflow np.exp.
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
# Draw one class index in [0, K) per sample from its probability row.
# The draws use the global NumPy RNG, so their order determines the labels.
y = np.array([np.random.choice(K, p=p) for p in probs])

# Convert data to JAX arrays
X_jax = jax.device_put(X)
y_jax = jax.device_put(y)

## pyensmallen + JAX

In [3]:
# Define the multinomial logistic regression model
def multinomial_logit(params, X):
    """Return per-class log-probabilities of a multinomial logit model.

    The last class is the reference category: its coefficients are fixed at
    zero and are therefore not part of ``params``.

    Args:
        params: Free coefficients; any array that can be reshaped to
            ``(n_features, n_classes - 1)`` where ``n_features = X.shape[1]``.
        X: Feature matrix of shape ``(n_samples, n_features)``.

    Returns:
        Array of shape ``(n_samples, n_classes)`` containing log-softmax
        values (log-probabilities), not raw logits.
    """
    n_features = X.shape[1]
    # Infer the layout from the inputs instead of the notebook globals D and K,
    # so the model works for any feature/class count.
    free_params = params.reshape(n_features, -1)
    # Append the all-zero column of the reference class.
    full_params = jnp.column_stack([free_params, jnp.zeros((n_features, 1))])
    return jax.nn.log_softmax(X @ full_params, axis=1)


# Define the loss function (negative log-likelihood)
def loss(params, X, y):
    """Mean negative log-likelihood of the labels under the model.

    Args:
        params: Model coefficients, passed through to ``multinomial_logit``.
        X: Feature matrix, passed through to ``multinomial_logit``.
        y: Integer class index per sample; used to index the class axis.

    Returns:
        Scalar: minus the average log-probability assigned to the observed
        class of each sample.
    """
    # multinomial_logit returns log-softmax values, i.e. log-probabilities.
    log_probs = multinomial_logit(params, X)
    sample_index = jnp.arange(y.shape[0])
    observed_log_probs = log_probs[sample_index, y]
    return -jnp.mean(observed_log_probs)


# Create JAX gradient function - autodiff!
# jax.grad differentiates `loss` with respect to its first argument (params);
# the result has the same shape as the params passed in.
grad_loss = jax.grad(loss)


# Define the objective function for pyensmallen
def objective(params, gradient, X, y):
    """Evaluate loss and gradient in the callback form passed to pyensmallen.

    Args:
        params: Flat parameter vector of length ``D * (K - 1)``.
        gradient: Array that is overwritten in place with the flattened
            gradient of the loss at ``params``.
        X: Feature matrix forwarded to ``loss``.
        y: Integer class labels forwarded to ``loss``.

    Returns:
        The loss at ``params`` as a Python float.
    """
    params_jax = jax.device_put(params.reshape(D, K - 1))
    # Use the X and y that were passed in (not the notebook globals), and get
    # value and gradient from a single forward/backward pass.
    loss_value, grad = jax.value_and_grad(loss)(params_jax, X, y)
    # Write into the caller's buffer rather than rebinding the name.
    gradient[:] = np.array(grad).flatten()
    return float(loss_value)


# Pyensmallen optimization
# The timer starts before the optimizer is built and the starting point is
# drawn, so ens_time covers setup as well as the optimization itself.
start_time = time.time()
optimizer = pyensmallen.L_BFGS()
# Flat starting vector of the free coefficients; drawn from the global NumPy
# RNG and reused by the Optax cell so both optimizers start from the same point.
initial_params = np.random.randn(D * (K - 1))
# The lambda adapts `objective` to a two-argument (params, gradient) callback.
result_ens = optimizer.optimize(
    lambda params, gradient: objective(params, gradient, X_jax, y_jax), initial_params
)
ens_time = time.time() - start_time
# Reshape the result to (D, K - 1) and append the zero reference-class column
# so it is comparable with true_coeffs, which has shape (D, K).
estimated_coeffs_ens = np.column_stack([result_ens.reshape(D, K - 1), np.zeros((D, 1))])

## JAX + Optax

In [4]:
# JAX optimization with Optax
# The timer starts before `step` is defined and first called, so jax_time
# also covers whatever one-off cost the first call to the jitted step incurs.
start_time = time.time()
# Rebinds `initial_params` from the flat NumPy vector to a (D, K - 1) JAX
# array, so Adam starts from the same point as the L-BFGS run.
initial_params = jnp.array(initial_params.reshape(D, K - 1))

# Define the Optax optimizer (using Adam as an example)
# Rebinds `optimizer`, which previously held the pyensmallen L_BFGS object.
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(initial_params)


@jax.jit
def step(params, opt_state, X, y):
    """Run one optimizer update on the full dataset.

    Args:
        params: Current model coefficients.
        opt_state: Current optimizer state matching ``params``.
        X: Feature matrix passed to ``loss``.
        y: Class labels passed to ``loss``.

    Returns:
        Tuple ``(params, opt_state, loss_value)`` with the updated
        coefficients, the updated optimizer state, and the loss evaluated at
        the coefficients that were passed in (before the update).
    """
    value_and_grad_fn = jax.value_and_grad(loss)
    loss_value, grads = value_and_grad_fn(params, X, y)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss_value


# Run a fixed number of full-batch Adam steps from the shared starting point.
params = initial_params
for i in range(2000):
    params, opt_state, _ = step(params, opt_state, X_jax, y_jax)

# Append the zero reference-class column so the shape matches true_coeffs (D, K).
estimated_coeffs_jax = jnp.column_stack([params, jnp.zeros((D, 1))])
# JAX dispatches work asynchronously: calls can return before the result is
# computed. Wait for the final array so the timing covers the computation.
estimated_coeffs_jax.block_until_ready()
jax_time = time.time() - start_time

## Comparison

In [5]:
# Sanity check: the flattened coefficient matrix has D * K entries.
true_coeffs.reshape(-1).shape

(40,)

In [6]:
# Side-by-side table, one row per coefficient in row-major order:
# column 0 = true value, column 1 = pyensmallen estimate, column 2 = Optax estimate.
# Every K-th row is the reference class and is zero in all three columns.
np.c_[true_coeffs.reshape(-1), estimated_coeffs_ens.reshape(-1), estimated_coeffs_jax.reshape(-1)]

array([[ 1.76405235,  1.81593705,  1.81090689],
       [ 0.40015721,  0.42660712,  0.42548722],
       [ 0.97873798,  0.9881251 ,  0.98618811],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.86755799,  1.86311812,  1.85787034],
       [-0.97727788, -0.99712363, -0.99704129],
       [ 0.95008842,  0.91224389,  0.91011459],
       [ 0.        ,  0.        ,  0.        ],
       [-0.10321885, -0.09046709, -0.09009396],
       [ 0.4105985 ,  0.40371618,  0.40340352],
       [ 0.14404357,  0.1855722 ,  0.18550713],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.76103773,  0.8096924 ,  0.80794239],
       [ 0.12167502,  0.16937907,  0.16907169],
       [ 0.44386323,  0.42813934,  0.42753083],
       [ 0.        ,  0.        ,  0.        ],
       [ 1.49407907,  1.5803887 ,  1.57713163],
       [-0.20515826, -0.13466159, -0.1348491 ],
       [ 0.3130677 ,  0.33455137,  0.33372569],
       [ 0.        ,  0.        ,  0.        ],
       [-2.55298982, -2.57077101, -2.563

In [7]:
# Compare results
# Wall-clock seconds measured around each optimizer run in the cells above.
print("Pyensmallen optimization time:", ens_time)
print("JAX optimization time:", jax_time)

# Mean absolute error against the generating coefficients.
# NOTE: the mean runs over all D * K entries, including the reference-class
# column, which is exactly zero in the truth and in both estimates; the
# reported value is therefore smaller than the MAE over the free coefficients.
mae_ens = np.mean(np.abs(true_coeffs - estimated_coeffs_ens))
mae_jax = np.mean(np.abs(true_coeffs - estimated_coeffs_jax))

print("\nPyensmallen Mean Absolute Error:", mae_ens)
print("JAX Mean Absolute Error:", mae_jax)

Pyensmallen optimization time: 1.0235660076141357
JAX optimization time: 2.4678776264190674

Pyensmallen Mean Absolute Error: 0.026384408255362625
JAX Mean Absolute Error: 0.025860388


In [8]:
def predict(coeffs, X):
    """Return the index of the highest-scoring class for each row of ``X``.

    Args:
        coeffs: Coefficient matrix with one column per class.
        X: Feature matrix with one row per sample.

    Returns:
        One class index per sample: the position of the largest entry in the
        corresponding row of ``X @ coeffs``.
    """
    # The argmax over linear scores is used directly, without a softmax.
    return np.argmax(X @ coeffs, axis=1)


# In-sample accuracy: fraction of the training labels recovered by each fit
# (evaluated on the same X and y the models were fitted on).
accuracy_ens = np.mean(predict(estimated_coeffs_ens, X) == y)
accuracy_jax = np.mean(predict(estimated_coeffs_jax, X) == y)

print("\nPyensmallen Accuracy:", accuracy_ens)
print("JAX Accuracy:", accuracy_jax)

# Final objective value of each fit. `loss` takes only the free coefficients,
# so the appended reference-class column is dropped with [:, :-1].
final_loss_ens = loss(jax.device_put(estimated_coeffs_ens[:, :-1]), X_jax, y_jax)
final_loss_jax = loss(estimated_coeffs_jax[:, :-1], X_jax, y_jax)

print("\nPyensmallen Final Loss:", final_loss_ens)
print("JAX Final Loss:", final_loss_jax)


Pyensmallen Accuracy: 0.7667
JAX Accuracy: 0.76669997

Pyensmallen Final Loss: 0.5785731
JAX Final Loss: 0.57857406
