# Linear Model Comparison: statsmodels, pyfixest, and jaxonometrics

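This notebook compares the runtime of three libraries for linear regression, `statsmodels`, `pyfixest`, and `jaxonometrics`, and checks that their coefficient estimates agree. The data come from a sparse data-generating process (DGP). The number of regressors grows from low-dimensional (p = 20) through high-dimensional (p = 2,000) to an ill-posed setting with more regressors than observations (p = 20,000, n = 10,000), which shows how the performance differences change with dimension.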

In [1]:
import numpy as np
import pandas as pd
import jax.numpy as jnp
import statsmodels.api as sm
from pyfixest.estimation import feols
from jaxonometrics.linear import LinearRegression
import time

## DGP

In [2]:
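def sparse_dgp(n=10_000, p=20_000, eta=0.1):
    # Design matrix: a column of ones (x0) followed by p i.i.d. standard-normal regressors
    X = np.c_[np.ones(n), np.random.normal(size=n * p).reshape((n, p))]
    # Sparse coefficients: int(eta * p) entries among indices 0..p-1 get N(0, 1) draws,
    # all other entries stay zero
    β, nzcount = np.repeat(0.0, p + 1), int(eta * p)
    nzid = np.random.choice(p, nzcount, replace=False)
    β[nzid] = np.random.randn(nzcount)
    # Heteroskedastic noise: standard deviation 0.6 when x1 > 0 and 0.5 otherwise
    e = np.random.normal(0, 0.5 + 0.1 * (X[:, 1] > 0), n)
    y = X @ β + e
    return y, X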
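The heteroskedastic noise scale is sensitive to operator precedence: in Python, comparisons bind more loosely than arithmetic, so `0.1 * x > 0` is evaluated as `(0.1 * x) > 0`. The small check below contrasts the two readings on a few hypothetical values of the first regressor.

```python
import numpy as np

# A few illustrative values of the first regressor (hypothetical)
x1 = np.array([-1.0, 0.5, 2.0])

# Comparison binds last: (0.1 * x1) > 0 is a boolean, so the sd is 0.5 or 1.5
sd_loose = 0.5 + (0.1 * x1 > 0)
# Indicator scaled by 0.1: the sd is 0.5 or 0.6
sd_scaled = 0.5 + 0.1 * (x1 > 0)

print(sd_loose)   # [0.5 1.5 1.5]
print(sd_scaled)  # [0.5 0.6 0.6]
```

The first reading triples the noise for half the sample; the second adds a mild 20% bump.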

In [3]:
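def onerep(X, y):
    # pandas DataFrame for pyfixest's formula interface: columns x0 (ones), x1, ..., and y
    df = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
    df["y"] = y
    # statsmodels: OLS on the raw design matrix, which already contains the constant x0
    start1 = time.perf_counter()
    sm_model = sm.OLS(y, X).fit()
    sm_time = time.perf_counter() - start1
    # pyfixest: formula listing every column of the design matrix
    start2 = time.perf_counter()
    fixest_model = feols(
        f"y ~ -1 + {'+'.join([f'x{i}' for i in range(X.shape[1])])}", data=df
    )
    fixest_time = time.perf_counter() - start2
    # jaxonometrics: JAX dispatches work asynchronously, so copy the coefficients back to
    # NumPy inside the timed region to be sure the computation has finished
    # (the timing also includes any JIT compilation on first use)
    start3 = time.perf_counter()
    jax_model = LinearRegression()
    jax_model.fit(jnp.array(X), jnp.array(y))
    jax_params = np.array(jax_model.params["beta"])
    jax_time = time.perf_counter() - start3
    return {
        "sm_time": sm_time,
        "fixest_time": fixest_time,
        "jax_time": jax_time,
        "sm_params": sm_model.params,
        "fixest_params": fixest_model.coef().values,
        "jax_params": jax_params,
    }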
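Timing JAX code has two pitfalls: dispatch is asynchronous, so a timer can stop before the computation finishes, and the first call of a jitted function also pays for tracing and XLA compilation. The sketch below separates the two using `jax.block_until_ready`. `jnp.linalg.lstsq` is a stand-in solver chosen for illustration (an assumption; it is not necessarily what jaxonometrics' `LinearRegression` uses), and `lstsq_beta` and `time_jax` are hypothetical helper names.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def lstsq_beta(X, y):
    # Least-squares coefficients; stand-in for the library solver being timed
    beta, _, _, _ = jnp.linalg.lstsq(X, y)
    return beta


def time_jax(fn, *args, reps=5):
    """Return (first-call seconds, best steady-state seconds) for a JAX function."""
    # First call: tracing + XLA compilation + execution
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - t0
    # Later calls reuse the compiled executable; block so we time the computation,
    # not just its asynchronous dispatch
    steady = []
    for _ in range(reps):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))
        steady.append(time.perf_counter() - t0)
    return first, min(steady)


rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 50))
y = X @ rng.normal(size=50) + rng.normal(size=1_000)
first, steady = time_jax(lstsq_beta, jnp.asarray(X), jnp.asarray(y))
print(f"first call: {first:.4f}s, steady state: {steady:.4f}s")
```

Reporting both numbers makes it clearer how much of a small problem's runtime is one-off overhead rather than the solve itself.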

### Low-Dim

In [4]:
y, X = sparse_dgp(n=10_000, p=20)
res = onerep(X, y)
print(f" sm time: {res['sm_time']:.4f}s, fixest time: {res['fixest_time']:.4f}s, jax time: {res['jax_time']:.4f}s")
res["sm_params"][:5], res["fixest_params"][:5], res["jax_params"][:5]

            1 variables dropped due to multicollinearity.
            The following variables are dropped: ['x0'].
            
 sm time: 0.0117s, fixest time: 1.0715s, jax time: 0.3841s


(array([-0.00219642,  0.01010658,  0.01650456,  0.00361118,  0.01222154]),
 array([-0.00219642,  0.01010658,  0.01650456,  0.00361118,  0.01222154]),
 array([-0.00219641,  0.01010658,  0.01650455,  0.00361117,  0.01222158],
       dtype=float32))

### High-Dim

In [5]:
y, X = sparse_dgp(n=10_000, p=2_000)
res = onerep(X, y)
print(f" sm time: {res['sm_time']:.4f}s, fixest time: {res['fixest_time']:.4f}s, jax time: {res['jax_time']:.4f}s")
res["sm_params"][:5], res["fixest_params"][:5], res["jax_params"][:5]

            1 variables dropped due to multicollinearity.
            The following variables are dropped: ['x0'].
            


 sm time: 14.8658s, fixest time: 14.2712s, jax time: 2.2547s


(array([ 0.00146385, -0.00174201, -0.01347367, -0.01551777, -0.02014805]),
 array([ 0.00146385, -0.00174201, -0.01347367, -0.01551777, -0.02014805]),
 array([ 0.00146354, -0.00174188, -0.01347361, -0.01551812, -0.02014804],
       dtype=float32))
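The JAX coefficients are `float32` (note the `dtype` in the outputs), JAX's default precision, which most likely accounts for their differences of order 1e-7 from the float64 estimates of statsmodels and pyfixest. JAX's `jax_enable_x64` flag switches on double precision if it is set before any arrays are created; whether jaxonometrics' `LinearRegression` then returns float64 coefficients is an assumption we have not checked. The sketch solves the same normal equations in both precisions on hypothetical data.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable double precision; run this before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 20))
y = X @ rng.normal(size=20) + rng.normal(size=500)

# Same normal-equations solve in float64 and float32 (illustration only)
X64, y64 = jnp.asarray(X), jnp.asarray(y)  # float64 now that x64 is enabled
beta64 = jnp.linalg.solve(X64.T @ X64, X64.T @ y64)
X32, y32 = X64.astype(jnp.float32), y64.astype(jnp.float32)
beta32 = jnp.linalg.solve(X32.T @ X32, X32.T @ y32)

print(beta64.dtype, beta32.dtype)  # float64 float32
print(float(jnp.max(jnp.abs(beta64 - beta32))))  # small rounding gap
```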

### Ultra-High-Dim / Ill-Posed

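Here p + 1 > n, so X'X is singular and infinitely many coefficient vectors fit the data exactly. A natural choice among them is the minimum-norm interpolator: the exact fit with the smallest Euclidean norm. It is easy to compute with lineax, the linear-solver library behind jaxonometrics, because lineax automatically adapts the solver to the problem.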
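When X has full row rank (true with probability one for a Gaussian design with more columns than rows), the minimum-norm interpolator has the closed form β̂ = X'(XX')⁻¹y, which equals the pseudoinverse solution X⁺y. The JAX sketch below checks this on a small hypothetical problem (n = 50, p = 200): β̂ interpolates y, matches `np.linalg.pinv`, and any other interpolator, obtained by adding a null-space vector, has a larger norm. This is the textbook formula rather than whatever solver lineax selects; forming XX' is fine for a toy problem but squares the condition number.

```python
import jax.numpy as jnp
import numpy as np


def min_norm_interpolator(X, y):
    """Minimum-norm solution of X @ beta = y for a wide X (p > n) with full row rank."""
    gram = X @ X.T                     # n x n Gram matrix, invertible under full row rank
    alpha = jnp.linalg.solve(gram, y)  # solve (X X') alpha = y
    return X.T @ alpha                 # beta = X' alpha lies in the row space of X


rng = np.random.default_rng(0)
n, p = 50, 200                         # toy sizes (hypothetical), p > n
X = jnp.asarray(rng.normal(size=(n, p)), dtype=jnp.float32)
y = jnp.asarray(rng.normal(size=n), dtype=jnp.float32)

beta = min_norm_interpolator(X, y)
print(float(jnp.max(jnp.abs(X @ beta - y))))  # ~0: fits the data exactly

# Any other interpolator differs by a null-space vector, which can only add norm
v = jnp.asarray(rng.normal(size=p), dtype=jnp.float32)
v_null = v - X.T @ jnp.linalg.solve(X @ X.T, X @ v)  # project v onto null(X)
other = beta + v_null
print(float(jnp.linalg.norm(beta)), float(jnp.linalg.norm(other)))  # first is smaller
```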

In [None]:
y, X = sparse_dgp(n=10_000, p=20_000)
res = onerep(X, y)
print(f" sm time: {res['sm_time']:.4f}s, fixest time: {res['fixest_time']:.4f}s, jax time: {res['jax_time']:.4f}s")
res["sm_params"][:5], res["fixest_params"][:5], res["jax_params"][:5]