We need to add a maximum iteration count to the forward solver; otherwise, SBL becomes very slow.

In [1]:
# %% Imports
import jax
from jax import jit, numpy as jnp, lax
from functools import partial

In [36]:
@partial(jit, static_argnums=(0,))
def fwd_solver(f, z_init, tol=1e-4, max_iter=300):
    """Iterate ``z <- f(z)`` until the iterates stop changing or a cap is hit.

    Args:
        f: Update function mapping ``z`` to a new value with the same pytree
            structure. Both a tuple such as ``(alpha, beta)`` and a stacked
            array such as ``[alpha, beta]`` are supported. ``f`` is a static
            argument of ``jit``, so it must be hashable.
        z_init: Initial value of ``z``.
        tol: Stop once the L2 norm of the change in every entry of ``z``
            except the last one (``z[:-1]``) drops below ``tol``.
        max_iter: Maximum number of loop iterations. ``f`` is applied once
            before the loop starts, so it runs at most ``max_iter + 1`` times.

    Returns:
        The last iterate, with the same structure as ``z_init``.
    """

    def _change_norm(z_prev, z):
        # ``[:-1]`` drops the last entry (beta) for both tuples and arrays.
        # Plain ``tuple - tuple`` raises TypeError, so subtract leaf by leaf
        # and combine the per-leaf squared norms into one L2 norm.
        diffs = jax.tree_util.tree_leaves(
            jax.tree_util.tree_map(lambda a, b: a - b, z_prev[:-1], z[:-1])
        )
        return jnp.sqrt(sum(jnp.sum(jnp.abs(d) ** 2) for d in diffs))

    def cond_fun(carry):
        iteration, z_prev, z = carry
        # Keep looping while the non-last entries still move by >= tol and
        # the iteration cap has not been reached.
        cond_norm = _change_norm(z_prev, z) < tol
        cond_iter = iteration >= max_iter
        return ~jnp.logical_or(cond_norm, cond_iter)

    def body_fun(carry):
        iteration, _, z = carry
        return iteration + 1, z, f(z)

    # Carry is (iteration count, previous iterate, current iterate).
    init_carry = (0, z_init, f(z_init))
    _, _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

In [37]:
def update(prior_params, X, y, eigvals, norms, hyper_prior_params):
    """Perform one fixed-point update of the prior parameters ``(alpha, beta)``.

    Args:
        prior_params: Current ``(alpha, beta)``.
        X: Design matrix of shape ``(n_samples, n_features)``.
        y: Targets, either 1-D ``(n_samples,)`` or 2-D ``(n_samples, k)``.
        eigvals: Eigenvalues of ``X.T @ X``, as computed in this notebook.
        norms: Per-column norms of ``X``, shape ``(n_features,)``.
        hyper_prior_params: ``(a, b)``, entering the beta update as
            ``2 * (a - 1)`` in the numerator and ``2 * b`` in the denominator.

    Returns:
        The updated ``(alpha, beta)`` tuple.
    """
    # Unpacking parameters
    alpha_prev, beta_prev = prior_params
    a, b = hyper_prior_params

    n_samples, n_features = X.shape
    # NOTE(review): eigvals (sorted eigenvalues) and norms (per-column) are
    # combined elementwise; this is only clearly meaningful when all norms are
    # equal, e.g. for column-normalised X as used in this notebook.
    gamma_ = jnp.sum(
        (beta_prev * eigvals) / (alpha_prev * norms + beta_prev * eigvals)
    )

    # Solve the linear system rather than forming the inverse explicitly.
    # This is cheaper and numerically more stable, and the inverse itself is
    # not needed for anything else.
    precision = beta_prev * X.T @ X + jnp.diag(alpha_prev * norms)
    mn = beta_prev * jnp.linalg.solve(precision, X.T @ y)

    # View mn as (n_features, k) so that norms[:, None] weights each feature's
    # row. For 1-D y, mn has shape (n_features,), and broadcasting it against
    # norms[:, None] directly would build an outer product.
    alpha = gamma_ / jnp.sum(norms[:, None] * mn.reshape(n_features, -1) ** 2)

    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((y - X @ mn) ** 2) + 2 * b)

    return (alpha, beta)

In [4]:
# Load the test problem. The .npy file holds a pickled dict-like object with
# 'y' and 'X' entries. allow_pickle=True can execute arbitrary code, so only
# use it on trusted files.
data = jnp.load('test_data.npy', allow_pickle=True).item()
y = data['y']
X = data['X']

# Scale every column of X to unit L2 norm and use the scaled matrix from here on.
X_normed = X / jnp.linalg.norm(X, axis=0)
X = X_normed

In [5]:
# Quantities of the (column-normalised) design matrix consumed by `update`.
eigvals = jnp.linalg.eigvalsh(X.T @ X)
norms = jnp.linalg.norm(X, axis=0)

# Starting point (alpha, beta) and hyper-prior (a, b).
prior_params_init = (1.0, 1.0 / jnp.var(y))
hyper_prior_params = (0.0, 0.0)

# Solver settings.
tol, max_iter = 1e-4, 5

In [6]:
def partial_update_fn(z):
    """Apply `update` to `z` with this notebook's data and hyper-prior bound in."""
    return update(z, X, y, eigvals, norms, hyper_prior_params)

In [7]:
# Reference run with JIT disabled so the solver executes eagerly.
with jax.disable_jit():
    prior_params = fwd_solver(
        partial_update_fn,
        prior_params_init,
        tol=tol,
        max_iter=300,
    )

In [8]:
# Display the (alpha, beta) returned by the preceding solver call.
prior_params

(DeviceArray(0.3566604, dtype=float32), DeviceArray(273403.28, dtype=float32))

In [9]:
# Same problem, now through the jitted solver and capped at two loop iterations.
prior_params = fwd_solver(
    partial_update_fn, prior_params_init, tol=tol, max_iter=2
)

In [10]:
# Display the (alpha, beta) from the iteration-capped run for comparison.
prior_params

(DeviceArray(0.3566477, dtype=float32), DeviceArray(273403.06, dtype=float32))

In [38]:
@jit
def bayesianregression(
    X, y, prior_params_init=None, hyper_prior_params=jnp.zeros((2, )), tol=1e-4, max_iter=300
):
    """Estimate the prior parameters ``[alpha, beta]`` for a linear model on ``X``.

    The columns of ``X`` are scaled to unit L2 norm, and the nested ``update``
    map is iterated to a fixed point by ``fixed_point_solver``.

    Args:
        X: Design matrix of shape ``(n_samples, n_terms)``.
        y: Targets.
        prior_params_init: Initial ``[alpha, beta]``. Defaults to
            ``[1.0, 1.0 / var(y)]``.
        hyper_prior_params: ``(a, b)``, entering the beta update as
            ``2 * (a - 1)`` and ``2 * b``.
        tol: Forwarded to ``fixed_point_solver``.
        max_iter: Forwarded to ``fixed_point_solver``.

    Returns:
        The result of ``fixed_point_solver``. NOTE(review): that helper is not
        defined in this view. Its argument order ``(update, args, init, ...)``
        is assumed from the call below, and the output shown in the notebook
        is a stacked ``[alpha, beta]`` array.
    """

    def update(prior_params, X, y, eigvals, hyper_prior_params):
        # Unpacking parameters. alpha_prev is the slice prior_params[:-1]
        # (length 1 for a 2-element input); beta_prev is the last entry.
        alpha_prev, beta_prev = prior_params[:-1], prior_params[-1]
        a, b = hyper_prior_params

        n_samples, n_terms = X.shape
        gamma_ = jnp.sum((beta_prev * eigvals) / (alpha_prev + beta_prev * eigvals))
        # Compute mn with a linear solve instead of an explicit inverse.
        # This is cheaper and numerically more stable, and the inverse matrix
        # itself is not used anywhere else.
        precision = beta_prev * X.T @ X + alpha_prev * jnp.eye(n_terms)
        mn = beta_prev * jnp.linalg.solve(precision, X.T @ y)

        # Update estimate
        alpha = gamma_ / jnp.sum(mn ** 2)
        beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((y - X @ mn) ** 2) + 2 * b)

        return jnp.stack([alpha, beta])

    # Constructing update function inputs from the column-normalised X.
    X_normed = X / jnp.linalg.norm(X, axis=0)
    eigvals = jnp.linalg.eigvalsh(X_normed.T @ X_normed)

    if prior_params_init is None:
        prior_params_init = jnp.stack([1.0, 1.0 / jnp.var(y)])

    # Calculating optimal prior
    prior_params = fixed_point_solver(
        update, (X_normed, y, eigvals, hyper_prior_params), prior_params_init, tol=tol, max_iter=max_iter
    )

    return prior_params

In [43]:
# Run the full jitted estimator on the data; the displayed output is the
# stacked [alpha, beta] array.
bayesianregression(X, y)

DeviceArray([3.5665095e-01, 2.7340316e+05], dtype=float32)

In [44]:
# Gradient of the estimated alpha (element 0 of the result) with respect to
# the design matrix X. How the derivative passes through the solver depends on
# fixed_point_solver, which is not shown here (unrolled vs. implicit
# differentiation — TODO confirm).
jax.grad(lambda x: bayesianregression(x, y)[0])(X)

DeviceArray([[-2.4349557e-04,  4.5630853e-03, -1.4199648e-02, ...,
               9.4683617e-03,  2.4732723e-04,  3.9888197e-03],
             [-1.0425382e-04,  1.4725944e-03, -1.6905170e-02, ...,
               4.2636171e-03,  1.3767934e-03, -3.2405937e-04],
             [-1.9567695e-05, -4.3725586e-04, -1.9224431e-02, ...,
               1.1115847e-03,  2.1444408e-03, -3.0756784e-03],
             ...,
             [-4.7815830e-04,  7.0107328e-03, -7.3108636e-02, ...,
               1.9427489e-02,  5.7208943e-03, -4.6256278e-04],
             [-2.7099944e-04,  3.1487744e-03, -6.0116306e-02, ...,
               1.1369799e-02,  5.4335683e-03, -3.7544412e-03],
             [-8.9526795e-05, -4.3830590e-04, -5.3410657e-02, ...,
               4.3993918e-03,  5.7278667e-03, -7.5057046e-03]],            dtype=float32)