In [1]:
import numpy as np
import jax.numpy as jnp
import lineax as lx
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression


In [2]:
# Simulation settings: n observations, p features, η = fraction of features with a nonzero effect.
# p > n, so the linear system X β = y is underdetermined (wide design matrix).
n, p, η = 10_000, 20_000, 0.1
SEED = 42
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces every output

# Design matrix: an intercept column of ones followed by p i.i.d. N(0, 1) features -> shape (n, p + 1).
X = np.c_[np.ones(n), rng.standard_normal((n, p))]

# Sparse true coefficients: β[0] is the intercept (left at 0), β[1:] are the p feature effects.
# Shift the sampled indices by 1 so they address feature columns 1..p; sampling from range(p)
# directly could hit the intercept and could never select the last feature.
nzcount = round(η * p)  # round, not int(): int() truncates e.g. 0.29 * 100 -> 28
β = np.zeros(p + 1)
nzid = 1 + rng.choice(p, nzcount, replace=False)
β[nzid] = rng.standard_normal(nzcount)

# Heteroskedastic noise: sd is 0.5, plus 0.1 wherever the first feature is positive.
# Parenthesise the comparison: `0.1 * X[:, 1] > 0` parses as `(0.1 * X[:, 1]) > 0`,
# which silently drops the 0.1 and yields sd 0.5 or 1.5.
σ = 0.5 + 0.1 * (X[:, 1] > 0)
e = rng.normal(0, σ, n)
y = X @ β + e


`lineax` provides a very fast least-squares solver, including for minimum-norm interpolation problems. Here $p > n$, so $X\beta = y$ typically has infinitely many exact solutions.


In [3]:
%%time
# Solve X β = y with lineax. X is wide (p + 1 columns > n rows), so the system has many
# exact solutions. With well_posed=None, AutoLinearSolver picks a solver from the operator's
# structure. NOTE(review): for a non-square matrix this is expected to be a QR-based
# least-squares / minimum-norm solve; confirm against the lineax AutoLinearSolver docs.
# jnp.array converts to float32 unless jax_enable_x64 is set, so this solve runs in single precision.
sol = lx.linear_solve(
    lx.MatrixLinearOperator(jnp.array(X)),
    jnp.array(y),
    solver=lx.AutoLinearSolver(well_posed=None),
)

betahat = sol.value
# Does it interpolate? Report the largest *absolute* residual; the max of the signed
# residuals would hide arbitrarily large negative misfits.
jnp.abs(y - X @ betahat).max()


CPU times: user 3min 2s, sys: 24.6 s, total: 3min 27s
Wall time: 20 s


Array(0.0001545, dtype=float32)

In [7]:
# Euclidean (L2) norm of the lineax solution, for comparison with the sklearn fit below.
np.linalg.norm(np.asarray(betahat), ord=None)


31.983023

In [4]:
%%time
# X already carries an explicit column of ones (the intercept), so disable sklearn's own
# intercept. This fits the same system X β = y as the lineax cell, making the two coefficient
# vectors directly comparable, and skips sklearn's centering of the large X matrix.
m = LinearRegression(fit_intercept=False)
m.fit(X, y)


CPU times: user 32min 29s, sys: 43.4 s, total: 33min 12s
Wall time: 3min 17s


In [8]:
# Norm of sklearn's full coefficient vector. intercept_ is prepended so the vector lines up with
# betahat, whose first entry is the intercept. When fit_intercept=False, sklearn sets
# intercept_ to 0.0 and this reduces to the norm of coef_.
np.linalg.norm(np.r_[m.intercept_, m.coef_])


31.992084478126035