# Linear Model Comparison: statsmodels, pyfixest, and jaxonometrics

This notebook compares the runtime of three libraries for linear regression: `statsmodels`, `pyfixest`, and `jaxonometrics`. The data come from a sparse data-generating process (DGP), fit first with few regressors (p = 20) and then with many (p = 1,000), to show how each library scales.

In [1]:
import warnings

import numpy as np
import pandas as pd

import jax

import jax.numpy as jnp
import statsmodels.api as sm
from pyfixest.estimation import feols
from jaxonometrics.linear import LinearRegression

In [2]:
jax.devices()

INFO:2025-06-29 15:23:13,113:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


[CpuDevice(id=0)]

In [3]:
a = jnp.array([1, 2, 3])
a.device

CpuDevice(id=0)

## DGP

In [4]:
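def sparse_dgp(n=10_000, p=20_000, eta=0.2):
    """Simulate y = X @ β + e with a sparse β and heteroskedastic noise."""
    # Intercept column followed by p standard-normal regressors
    X = np.c_[np.ones(n), np.random.normal(size=(n, p))]
    β, nzcount = np.zeros(p + 1), int(eta * p)
    # int(eta * p) coefficients are non-zero (eta is a fraction, not a percent)
    nzid = np.random.choice(p, nzcount, replace=False)
    β[nzid] = np.random.uniform(-2, 3, nzcount)
    # Heteroskedasticity: noise sd is 1.5 where X[:, 1] > 0 and 0.5 otherwise
    e = np.random.normal(0, 0.5 + (0.2 * X[:, 1] > 0), n)
    y = X @ β + e
    return y, X, β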
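
Because `sparse_dgp` draws from NumPy's global random state, every timed loop sees different data. A minimal sketch of a seeded variant follows; the name `sparse_dgp_seeded`, the `seed` argument, and the smaller default sizes are assumptions made for this illustration and are not part of the notebook. It makes runs repeatable and lets the sparsity be checked directly:

```python
import numpy as np


def sparse_dgp_seeded(n=1_000, p=50, eta=0.2, seed=0):
    """Hypothetical seeded variant of sparse_dgp (name and seed are assumptions)."""
    rng = np.random.default_rng(seed)
    # Intercept column followed by p standard-normal regressors
    X = np.c_[np.ones(n), rng.normal(size=(n, p))]
    beta = np.zeros(p + 1)
    nzcount = int(eta * p)
    # As in the original, indices come from 0..p-1, so the intercept
    # (index 0) may be selected and the last regressor never is
    nzid = rng.choice(p, nzcount, replace=False)
    beta[nzid] = rng.uniform(-2, 3, nzcount)
    # Noise sd is 1.5 where X[:, 1] > 0 and 0.5 otherwise
    scale = 0.5 + (0.2 * X[:, 1] > 0)
    e = rng.normal(0, scale, n)
    y = X @ beta + e
    return y, X, beta


y, X, beta = sparse_dgp_seeded()
print(X.shape, y.shape, np.count_nonzero(beta))
```

Calling the function twice with the same `seed` returns identical arrays, which is convenient when comparing coefficient estimates across libraries rather than only timings.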

## Low-Dim

In [5]:
n, p = 10_000, 20

### statsmodels

In [6]:
%%timeit -n 100 -r 5
y, X, beta = sparse_dgp(n, p)
m = sm.OLS(y, X).fit()

24.5 ms ± 930 μs per loop (mean ± std. dev. of 5 runs, 100 loops each)
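
Each `%%timeit` cell in this notebook regenerates the data inside the timed loop, so the reported times include `sparse_dgp` as well as the fit. A rough sketch of how the two costs could be separated with the standard-library `timeit` module is shown here. It uses `np.linalg.lstsq` as a stand-in estimator and a simplified `make_data` helper, both of which are assumptions for illustration and not one of the three libraries compared:

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
n, p = 10_000, 20


def make_data():
    # Stand-in for sparse_dgp: intercept plus p Gaussian regressors
    X = np.c_[np.ones(n), rng.normal(size=(n, p))]
    y = X @ rng.uniform(-2, 3, p + 1) + rng.normal(size=n)
    return y, X


def fit(y, X):
    # Least-squares fit; returns the coefficient vector only
    return np.linalg.lstsq(X, y, rcond=None)[0]


y, X = make_data()
reps = 20
# Average seconds per call for data generation alone and for the fit alone
t_data = timeit.timeit(make_data, number=reps) / reps
t_fit = timeit.timeit(lambda: fit(y, X), number=reps) / reps
print(f"data generation: {t_data * 1e3:.2f} ms, fit only: {t_fit * 1e3:.2f} ms")
```

Since all libraries here pay the same data-generation cost, the ranking is unaffected, but the ratios between libraries are compressed relative to a fit-only measurement.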


### pyfixest

In [7]:
%%timeit -n 100 -r 5
y, X, beta = sparse_dgp(n, p)
df = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
df["y"] = y
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fixest_model = feols(
        f"y ~ -1 + {'+'.join([f'x{i}' for i in range(X.shape[1])])}", data=df
    )

126 ms ± 2.79 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)
Compiler time: 0.19 s
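
The formula passed to `feols` is assembled from the column names. As a small illustration with three columns (the names `cols` and `formula` exist only in this sketch), the f-string expands as follows. The `-1` term is there because `X` already carries its own column of ones as `x0`:

```python
# Build the same kind of formula string as the cell above, for three columns
cols = [f"x{i}" for i in range(3)]
formula = f"y ~ -1 + {'+'.join(cols)}"
print(formula)  # y ~ -1 + x0+x1+x2
```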


### jaxonometrics

#### lineax backend 

In [8]:
%%timeit -n 100 -r 5
y, X, beta = sparse_dgp(n, p)
Xj, yj = jnp.array(X), jnp.array(y)
LinearRegression(solver="lineax").fit(Xj, yj, se="HC1")

9.56 ms ± 1.5 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)


#### jax linalg solve backend

In [9]:
%%timeit -n 100 -r 5
y, X, beta = sparse_dgp(n, p)
Xj, yj = jnp.array(X), jnp.array(y)
LinearRegression(solver="jax").fit(Xj, yj, se="HC1")

12.8 ms ± 1.03 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)
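
Two JAX behaviours matter when reading timings like these: a jitted function is compiled on its first call, and results are dispatched asynchronously, so a timer can stop before the computation has finished unless the result is explicitly awaited. Whether `LinearRegression.fit` is jitted or blocks internally is not shown in this notebook, so the sketch below uses a plain normal-equations solve (the `ols_solve` helper is an assumption for illustration) to show the pattern:

```python
import time

import jax
import jax.numpy as jnp


def ols_solve(X, y):
    # Solve the normal equations (X'X) b = X'y
    return jnp.linalg.solve(X.T @ X, X.T @ y)


ols_jit = jax.jit(ols_solve)

kx, ke = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (1_000, 20))
beta_true = jnp.arange(20, dtype=X.dtype)
y = X @ beta_true + 0.1 * jax.random.normal(ke, (1_000,))

# First call: includes tracing and compilation
t0 = time.perf_counter()
ols_jit(X, y).block_until_ready()
t_first = time.perf_counter() - t0

# Later calls reuse the compiled function; block_until_ready waits for the result
t0 = time.perf_counter()
beta_hat = ols_jit(X, y).block_until_ready()
t_steady = time.perf_counter() - t0

print(f"first call: {t_first * 1e3:.2f} ms, steady state: {t_steady * 1e3:.2f} ms")
```

With `%%timeit -n 100`, any one-off compilation cost is averaged over many loops, which may partly explain a larger spread in some of the JAX timings; that reading is a guess, not something measured here.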


#### numpy backend

In [10]:
%%timeit -n 100 -r 5
y, X, beta = sparse_dgp(n, p)
LinearRegression(solver="numpy").fit(X, y, se="HC1")

8.71 ms ± 782 μs per loop (mean ± std. dev. of 5 runs, 100 loops each)
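
The jaxonometrics cells request `se="HC1"`. For reference, here is a minimal NumPy sketch of the textbook HC1 sandwich estimator. It is not taken from the jaxonometrics source and may differ from the library's implementation in detail; the simulated data and all variable names are assumptions for this example:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 2_000, 5
X = np.c_[np.ones(n), rng.normal(size=(n, k - 1))]
beta_true = np.arange(1.0, k + 1)
# Heteroskedastic noise: sd is 1.5 where X[:, 1] > 0 and 0.5 otherwise
e = rng.normal(0, 0.5 + (X[:, 1] > 0), n)
y = X @ beta_true + e

# OLS coefficients and residuals
XtX_inv = np.linalg.inv(X.T @ X)
beta_hat = XtX_inv @ X.T @ y
resid = y - X @ beta_hat

# "Meat" of the sandwich: X' diag(resid^2) X, without forming an n x n matrix
meat = (X * resid[:, None] ** 2).T @ X
# HC1 rescales the HC0 sandwich by n / (n - k)
V_hc1 = n / (n - k) * XtX_inv @ meat @ XtX_inv
se_hc1 = np.sqrt(np.diag(V_hc1))
print(beta_hat.round(3), se_hc1.round(4))
```

Note that the `statsmodels` and `pyfixest` cells above are called without an explicit robust-covariance option, so the libraries are not necessarily doing identical work in each timed fit.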


## High-Dim

In [11]:
n, p = 10_000, 1_000

### pyfixest

In [12]:
%%timeit -n 10 -r 5
y, X, beta = sparse_dgp(n, p)
df = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
df["y"] = y
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fixest_model = feols(
        f"y ~ -1 + {'+'.join([f'x{i}' for i in range(X.shape[1])])}", data=df
    )


1.42 s ± 28.3 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


### jaxonometrics

#### lineax backend 

In [13]:
%%timeit -n 10 -r 5
y, X, beta = sparse_dgp(n, p)
Xj, yj = jnp.array(X), jnp.array(y)
LinearRegression(solver="lineax").fit(Xj, yj, se="HC1")

1.6 s ± 60 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


#### jax linalg solve backend

In [14]:
%%timeit -n 10 -r 5
y, X, beta = sparse_dgp(n, p)
Xj, yj = jnp.array(X), jnp.array(y)
LinearRegression(solver="jax").fit(Xj, yj, se="HC1")

2.13 s ± 669 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


### statsmodels

`statsmodels` slows to a crawl at this size, so the cell below has no recorded timing.

In [None]:
%%timeit -n 10 -r 5
y, X, beta = sparse_dgp(n, p)
m = sm.OLS(y, X).fit()