We can make some efficiency changes to our bayes, mostly precalculating some arrays. we do that here:

In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from functools import partial
from jax import lax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from flax import linen as nn
from typing import Tuple
from code import fwd_solver, bayes_ridge_update, evidence, BayesianRegression, fixed_point_solver
from sklearn.linear_model import BayesianRidge
from jax.scipy.stats import gamma
%load_ext autoreload
%autoreload 2

In [2]:
data = jnp.load('test_data.npy', allow_pickle=True).item()
y, X = data['y'], data['X']

Let's first get our baseline:

In [3]:
n_samples = X.shape[0]
hyper_prior =  jnp.stack([n_samples / 2, 1 / (n_samples / 2 * 1e-4)],  axis=0)

In [4]:
prior_init = jnp.stack([1., 1. / jnp.var(y)], axis=0)
update = jax.jit(lambda prior: bayes_ridge_update(prior_params=prior, y=y, X=X, hyper_prior_params=hyper_prior))
_ = update(prior_init) # compiling

In [5]:
%%timeit
z_star = fwd_solver(update, prior_init, tol=1e-4)

2.67 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
key = random.PRNGKey(42)
model = BayesianRegression(hyper_prior, tol=1e-4)
inputs = (y, X)
variables = model.init(key, inputs)

In [7]:
update = jax.jit(lambda variables, inputs: model.apply(variables, inputs, mutable=['bayes']))
_ = update(variables, inputs)

In [8]:
%%timeit
y, updated_state = update(variables, inputs)

2.34 ms ± 94.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Now let's make a more efficient update by precalculating the gram, normalization and eigvals:

In [11]:
def bayes_ridge_update_efficient(prior_params, y, X, gram, eigvals, hyper_prior_params):
    # Unpacking parameters
    alpha_prev, beta_prev = prior_params
    a, b = hyper_prior_params
  
    # Calculating intermediate matrices
    n_samples, n_terms = X.shape
    gamma_ = jnp.sum((beta_prev * eigvals) / (alpha_prev + beta_prev * eigvals))
    S = jnp.linalg.inv(beta_prev * gram + alpha_prev * jnp.eye(n_terms))
    mn = beta_prev * S @ X.T @ y

    # Update estimate
    alpha = gamma_ / jnp.sum(mn ** 2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (
        jnp.sum((y - X @ mn) ** 2) + 2 * b
    )

    return jnp.stack([alpha, beta], axis=0)

In [21]:
prior_init = jnp.stack([1., 1. / jnp.var(y)], axis=0)
X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)
gram = X_normed.T @ X_normed
eigvals = jnp.linalg.eigvalsh(gram)

In [22]:
update = jax.jit(lambda prior: bayes_ridge_update_efficient(prior_params=prior, y=y, X=X_normed, gram=gram, eigvals=eigvals, hyper_prior_params=hyper_prior))
_ = update(prior_init) # compiling

In [23]:
layer = jax.jit(fwd_solver, static_argnums=(0, ))

In [24]:
%%timeit
z_star = layer(update, prior_init, tol=1e-4)

1.07 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


So we cut it down with a factor of two :-). Now let's put it nicely in a layer:

In [10]:
class BayesianRegressionEfficient(nn.Module):
    hyper_prior: Tuple
    tol: float = 1e-3

    def setup(self):
        self.update = lambda prior, y, X, gram, eigvals: bayes_ridge_update_efficient(
            prior, y, X, gram, eigvals, self.hyper_prior
        )

    @nn.compact
    def __call__(self, inputs):
        is_initialized = self.has_variable("bayes", "z")
        z_init = self.variable(
            "bayes", "z", lambda y: jnp.stack([1, 1 / jnp.var(y)], axis=0), inputs[0]
        )
        
        y, X = inputs
        X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)
        gram = X_normed.T @ X_normed
        eigvals = jnp.linalg.eigvalsh(gram)
        
        z_star = fixed_point_solver(self.update, (y, X_normed, gram, eigvals), z_init.value, tol=self.tol)
        if is_initialized:
            z_init.value = z_star
        return z_star


In [12]:
key = random.PRNGKey(42)
model = BayesianRegressionEfficient(hyper_prior, tol=1e-4)
inputs = (y, X)
variables = model.init(key, inputs)

In [14]:
update = jax.jit(lambda variables, inputs: model.apply(variables, inputs, mutable=['bayes']))
_ = update(variables, inputs)

In [15]:
%%timeit
y, updated_state = update(variables, inputs)

1.45 ms ± 6.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
