In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from flax import optim
from modax.models import Deepmod
from modax.training import create_update
from modax.losses import loss_fn_pinn
from modax.logging import Logger

from sklearn.linear_model import BayesianRidge
from jax.scipy.stats import gamma
from modax.data.burgers import burgers
from time import time

from functools import partial
from jax import lax

In [2]:
# Making dataset
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)

In [3]:
# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

In [4]:
# Running to convergence
max_epochs = 10001
logger = Logger()
for epoch in jnp.arange(max_epochs):
    optimizer, metrics = update(optimizer)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 100 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.060140594840049744
Loss step 1000: 3.848650885629468e-05
Loss step 2000: 1.3884466625313507e-06
Loss step 3000: 4.39812481545232e-07
Loss step 4000: 4.0220501773546857e-07
Loss step 5000: 3.040478304683347e-07
Loss step 6000: 2.48571353722582e-07
Loss step 7000: 2.749341661001381e-07
Loss step 8000: 1.2674520633026987e-07
Loss step 9000: 1.1675654576492889e-07
Loss step 10000: 1.1416497613936372e-07


In [5]:
prediction, dt, theta, coeffs = model.apply(optimizer.target, X_train)

In [None]:
# First normalize theta
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [None]:
# Quick check with OLS
jnp.linalg.lstsq(theta_normed, dt)[0]

In [None]:
# Calculate eigenvalues and set and initial b
a = 1
b = 1 / jnp.var(dt)

gram = theta_normed.T @ theta_normed

l = jnp.linalg.eigvalsh(gram)

n_samples = theta_normed.shape[0]
n_terms = theta_normed.shape[1]

In [None]:
# updating gamma
gamma = jnp.sum(b * l / (a + b * l))

In [None]:
# Calculating posterior mean
S = jnp.linalg.inv(b * gram + a * jnp.eye(n_terms))
mn = b * S @ theta_normed.T @ dt

print(mn)

In [None]:
# Updating a and b
a = gamma / mn.T @ mn
b = (n_samples - gamma) / jnp.sum((dt - theta_normed @ mn)**2)

In [None]:
a

In [None]:
b

That seems to work, now lets do it in a loop:

In [None]:
@jax.jit
def update(alpha_, beta_, lambda_, t, Phi, gram):
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]

    gamma_ = jnp.sum((beta_ * lambda_) / (alpha_ + beta_ * lambda_))
    S = jnp.linalg.inv(beta_ * gram + alpha_ * jnp.eye(n_terms)) # Change to QR?
    mn = beta_ * S @ Phi.T @ t
    
    alpha_new = gamma_ / (mn.T @ mn).squeeze()
    beta_new = (n_samples - gamma_) / jnp.sum((t - Phi @ mn)**2)
    
    return alpha_new, beta_new, mn

In [None]:
update(a, b, l, dt, theta_normed, gram)

In [None]:
(a, b, mn), a_old, b_old = update(a, b, l, dt, theta_normed, gram), a, b

In [None]:
while jnp.linalg.norm((a - a_old)) > 1e-4:
    (a, b, _), a_old, b_old = update(a, b, l, dt, theta_normed, gram), a, b
    print(b - b_old)

In [None]:
jnp.linalg.norm((b - b_old))

In [None]:
update(a, b, l, dt, theta_normed, gram)[2]

In [None]:
jnp.linalg.lstsq(theta_normed, dt)[0]

In [None]:
print(a)

In [None]:
reg = BayesianRidge(fit_intercept=False)

In [None]:
%%time
reg.fit(theta_normed, dt.squeeze()).coef_[:, None]

In [None]:
reg.alpha_

In [None]:
reg.lambda_

In [None]:
b

In [None]:
a

In [None]:
(a - reg.lambda_) / reg.lambda_

In [None]:
(b - reg.alpha_) / reg.alpha_ * 100

Now lets bring it all back together:

In [None]:
@jax.jit
def bayes_ridge_update(alpha_, beta_, lambda_, t, Phi, gram):
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]

    gamma_ = jnp.sum((beta_ * lambda_) / (alpha_ + beta_ * lambda_))
    S = jnp.linalg.inv(beta_ * gram + alpha_ * jnp.eye(n_terms)) # Change to QR?
    mn = beta_ * S @ Phi.T @ t
    
    alpha_new = gamma_ / (mn.T @ mn).squeeze()
    beta_new = (n_samples - gamma_) / jnp.sum((t - Phi @ mn)**2)
    
    return alpha_new, beta_new, mn

In [None]:
# Calculate eigenvalues and set and initial b
prediction, dt, theta, coeffs = model.apply(optimizer.target, X_train)

a = 1
b = 1 / jnp.var(dt)

theta_normed = theta / jnp.linalg.norm(theta, axis=0)
gram = theta_normed.T @ theta_normed
l = jnp.linalg.eigvalsh(gram)

In [None]:
update = jax.jit(partial(bayes_ridge_update, lambda_=l, t=dt, Phi=theta_normed, gram=gram))
(a, b, _), a_old, b_old = update(a, b), a, b

In [None]:
%%time
while jnp.linalg.norm((a - a_old)) > 1e-3:
    (a, b, mn), a_old, b_old = update(a, b), a, b

In [None]:
mn

We put the training loop in a nice function:

In [None]:
@jax.jit
def bayes_ridge_update(alpha_, beta_, lambda_, t, Phi, gram):
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]

    gamma_ = jnp.sum((beta_ * lambda_) / (alpha_ + beta_ * lambda_))
    S = jnp.linalg.inv(beta_ * gram + alpha_ * jnp.eye(n_terms)) # Change to QR?
    mn = beta_ * S @ Phi.T @ t
    
    alpha_new = gamma_ / (mn.T @ mn).squeeze()
    beta_new = (n_samples - gamma_) / jnp.sum((t - Phi @ mn)**2)
    
    return alpha_new, beta_new, mn

def bayesian_ridge(dt, theta, init_vals):
    a, b = init_vals
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    gram = theta_normed.T @ theta_normed
    l = jnp.linalg.eigvalsh(gram)

    # Making update function
    update = jax.jit(partial(bayes_ridge_update, lambda_=l, t=dt, Phi=theta_normed, gram=gram))
    (a, b, _), a_old, b_old = update(a, b), a, b
    
    # Running to convergence
    while jnp.linalg.norm((a - a_old)) > 1e-3:
        (a, b, mn), a_old, b_old = update(a, b), a, b
    mn = mn / jnp.linalg.norm(theta, axis=0)[:, None]
    return a, b, mn

In [None]:
%%time
a, b, mn = bayesian_ridge(dt, theta, init_vals=(1, 1 / jnp.var(dt)))
print(a, b, mn)

Let's calculate the loss:

In [None]:
def evidence(Phi, t, mn, alpha_, beta_):
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]
    
    A = alpha_ * jnp.eye(n_terms) + beta_ * Phi.T @ Phi
    E = beta_ / 2 * jnp.sum((t - Phi @ mn)**2) + alpha_ / 2 * mn.T @ mn
    p = n_terms / 2 * jnp.log(alpha_) + n_samples / 2 * jnp.log(beta_) - E - 1/2 * jnp.linalg.slogdet(A)[1] - n_samples / 2 * jnp.log(2 * jnp.pi)
    return p.squeeze()

In [None]:
evidence(theta, dt, mn, a, b)

Now lets include a prior on beta:

In [None]:
@jax.jit
def bayes_ridge_update(alpha_, beta_, lambda_, t, Phi, gram, prior_params):
    a, b = prior_params
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]

    gamma_ = jnp.sum((beta_ * lambda_) / (alpha_ + beta_ * lambda_))
    S = jnp.linalg.inv(beta_ * gram + alpha_ * jnp.eye(n_terms)) # Change to QR?
    mn = beta_ * S @ Phi.T @ t
    
    alpha_new = gamma_ / jnp.sum(mn**2)
    beta_new = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((t - Phi @ mn)**2) + 2 * b)
    
    return alpha_new, beta_new, mn


def bayesian_ridge(dt, theta, init_vals, prior_params):
    a, b = init_vals
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    gram = theta_normed.T @ theta_normed
    l = jnp.linalg.eigvalsh(gram)

    # Making update function
    update = jax.jit(partial(bayes_ridge_update, lambda_=l, t=dt, Phi=theta_normed, gram=gram, prior_params=prior_params))
    (a, b, _), a_old, b_old = update(a, b), a, b
    
    # Running to convergence
    while jnp.linalg.norm((a - a_old)) > 1e-4:
        (a, b, mn), a_old, b_old = update(a, b), a, b
    mn = mn / jnp.linalg.norm(theta, axis=0)[:, None]
    return a, b, mn

In [None]:
jax.jit(bayesian_ridge, static_argnums=(2, 3))(dt, theta, init, prior_params)

In [None]:
%%time
n_samples = theta.shape[0]
prior_params = (n_samples/2, 1/(n_samples/2 * 1e-4))
init = (1, 1 / jnp.var(dt))
a, b, mn = bayesian_ridge(dt, theta, init, prior_params)
print(a, b, mn)

Let's calculate the loss:

In [None]:
def evidence(Phi, t, mn, alpha_, beta_, prior_params):
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]
    a, b = prior_params
    
    A = alpha_ * jnp.eye(n_terms) + beta_ * Phi.T @ Phi
    E = beta_ / 2 * jnp.sum((t - Phi @ mn)**2) + alpha_ / 2 * mn.T @ mn
    log_p = n_terms / 2 * jnp.log(alpha_) + n_samples / 2 * jnp.log(beta_) - E - 1/2 * jnp.linalg.slogdet(A)[1] - n_samples / 2 * jnp.log(2 * jnp.pi)
    log_p += gamma.logpdf(beta_, a=a, scale=b)
    return log_p.squeeze()

In [None]:
evidence(theta, dt, mn, a, b, prior_params)

In [None]:
from typing import Sequence, Tuple
from modax.feature_generators import library_backward, library_forward
from modax.layers import LeastSquares, LeastSquaresMT
from modax.networks import MLP, MultiTaskMLP
from flax import linen as nn


class Deepmod(nn.Module):
    """Simple feed-forward NN.
    """

    features: Sequence[int]  # this is dataclass, so we dont use __init__
    prior_params: Tuple
    @nn.compact 
    def __call__(self, inputs):
        prediction, dt, theta = library_backward(MLP(self.features), inputs)
        a, b, coeffs = bayesian_ridge(dt, theta, (1, 1 / jnp.var(dt)), self.prior_params) 
        return prediction, dt, theta, coeffs, (a, b)

In [None]:
# Making dataset
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)

In [None]:
# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1], prior_params=prior_params)
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

In [None]:
def loss_fn_pinn_bayes_reg(params, model, x, y):
    prediction, dt, theta, coeffs, fit_params = model.apply(params, x)
    n_samples = prediction.shape[0]
    a0, b0 = n_samples / 2, 1 / (n_samples / 2 * 1e-4)

    sigma_ml = jnp.mean((prediction - y) ** 2)
    tau = 1 / sigma_ml
    MSE = neg_LL(prediction, y, tau)

    Reg = -evidence(theta, dt, coeffs, fit_params[0], fit_params[1], (a0, b0))
    prior = -jnp.sum(gamma.logpdf(beta, a=a0, scale=b0))

    loss = MSE + Reg + prior

    metrics = {
        "loss": loss,
        "mse": MSE,
        "reg": Reg,
        "coeff": coeffs,
        "tau": tau,
        "beta": fit_params[1],
        "alpha": fit_params[0]
    }
    return loss, metrics

In [None]:
# Compiling train step
update = create_update(loss_fn_pinn_bayes_reg, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

In [None]:
# Running to convergence
max_epochs = 10001
logger = Logger()
for epoch in jnp.arange(max_epochs):
    optimizer, metrics = update(optimizer)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 100 == 0:
        logger.write(metrics, epoch)
logger.close()

# Implicit layers code

In [6]:
from functools import partial
import jax
from jax import random, lax, numpy as jnp

In [22]:
ndim = 10
W = random.normal(random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
x = random.normal(random.PRNGKey(1), (ndim,))

In [23]:
def fwd_solver(f, z_init):
    z_prev, z = z_init, f(z_init)
    while jnp.linalg.norm(z_prev - z) > 1e-3:
        z_prev, z = z, f(z)
    return z

def fixed_point_layer(solver, f, params, x):
    z_star = solver(lambda z: f(params, x, z), z_init=jnp.zeros_like(x))
    return z_star

In [24]:
f = lambda W, x, z: jnp.tanh(jnp.dot(W, z) + x)
z_star = fixed_point_layer(fwd_solver, f, W, x)
print(z_star)

[ 0.00632886 -0.70152855 -0.9847213  -0.0419194  -0.6151645  -0.48185453
  0.5783277   0.9556748  -0.08354193  0.8447265 ]


In [25]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
    z_star = solver(lambda z: f(params, x, z), z_init=jnp.zeros_like(x))
    return z_star

def fixed_point_layer_fwd(solver, f, params, x):
    z_star = fixed_point_layer(solver, f, params, x)
    return z_star, (params, x, z_star)

def fixed_point_layer_bwd(solver, f, res, z_star_bar):
    params, x, z_star = res
    _, vjp_a = jax.vjp(lambda params, x: f(params, x, z_star), params, x)
    _, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)
    return vjp_a(solver(lambda u: vjp_z(u)[0] + z_star_bar,
                      z_init=jnp.zeros_like(z_star)))

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [26]:
g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, x).sum())(W)
print(g[0])

[ 0.00733157 -0.81267565 -1.1407362  -0.04856092 -0.7126285  -0.55819744
  0.66995543  1.1070877  -0.09677795  0.9785612 ]


In [51]:
def fwd_solver(f, z_init):
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z_prev - z) > 1e-5

    def body_fun(carry):
        _, z = carry
        return z, f(z)

    init_carry = (z_init, f(z_init))
    _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

In [31]:
%%time
z_star = fixed_point_layer(fwd_solver, f, W, x)
print(z_star)

[ 0.00649604 -0.7015957  -0.98471504 -0.04196557 -0.61522186 -0.48183814
  0.5783122   0.95567054 -0.08373152  0.8447805 ]


In [32]:
fp_layer_jit = jax.jit(fixed_point_layer, static_argnums=(0, 1))
fp_layer_jit(fwd_solver, f, W, x)

DeviceArray([ 0.00649604, -0.7015957 , -0.98471504, -0.04196557,
             -0.61522186, -0.48183814,  0.5783122 ,  0.95567054,
             -0.08373152,  0.8447805 ], dtype=float32)

In [33]:
%%timeit
z_star = fixed_point_layer(fwd_solver, f, W, x)

39.6 ms ± 35.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [34]:
%%timeit
z_star = fp_layer_jit(fwd_solver, f, W, x)

601 µs ± 254 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Now lets adapt our bayesian regression code to this:

In [7]:
def fwd_solver(f, z_init):
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z_prev[0] - z[0]) > 1e-3

    def body_fun(carry):
        _, z = carry
        return z, f(z)

    init_carry = (z_init, f(z_init))
    _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

@jax.jit
def bayes_ridge_update(alpha_, beta_, lambda_, t, Phi, gram, prior_params):
    
    a, b = prior_params
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]

    gamma_ = jnp.sum((beta_ * lambda_) / (alpha_ + beta_ * lambda_))
    S = jnp.linalg.inv(beta_ * gram + alpha_ * jnp.eye(n_terms)) # Change to QR?
    mn = beta_ * S @ Phi.T @ t
    
    alpha_new = gamma_ / jnp.sum(mn**2)
    beta_new = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((t - Phi @ mn)**2) + 2 * b)
    
    return alpha_new, beta_new

In [65]:
# Getting prerequisites in order
n_samples = dt.shape[0]
init_vals = (1, 1 / jnp.var(dt))
prior_params = (n_samples/2, 1/(n_samples/2 * 1e-4))

theta_normed = theta / jnp.linalg.norm(theta, axis=0)
gram = theta_normed.T @ theta_normed
l = jnp.linalg.eigvalsh(gram)

In [66]:
update = lambda z: bayes_ridge_update(z[0], z[1], lambda_=l, t=dt, Phi=theta_normed, gram=gram, prior_params=prior_params)

In [67]:
%%timeit
fwd_solver(update, init_vals)

206 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


So it seems to work. Can we jit?

In [74]:
fwd_solver_jit = jax.jit(fwd_solver, static_argnums=(0))

In [76]:
%%timeit
fwd_solver_jit(update, init_vals)

1.18 ms ± 4.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Wooosh. Now we put it in a nice function:

In [81]:
def bayesian_ridge(dt, theta):
    # Getting prerequisites in order
    n_samples = dt.shape[0]
    init_vals = (1, 1 / jnp.var(dt))
    prior_params = (n_samples/2, 1/(n_samples/2 * 1e-4))

    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    gram = theta_normed.T @ theta_normed
    l = jnp.linalg.eigvalsh(gram)
    
    update = lambda z: bayes_ridge_update(z[0], z[1], lambda_=l, t=dt, Phi=theta_normed, gram=gram, prior_params=prior_params)
    return fwd_solver(update, init_vals)

bayesian_ridge_jit = jax.jit(bayesian_ridge)

In [83]:
%%timeit
bayesian_ridge(dt, theta)

219 ms ± 848 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [84]:
%%timeit
bayesian_ridge_jit(dt, theta)

1.48 ms ± 3.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Now we clean up the code a bit and bring it more in line with the jax implicit layers code:

In [51]:
def fwd_solver(f, z_init, tol=1e-3):
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z_prev[0] - z[0]) > tol

    def body_fun(carry):
        _, z = carry
        return z, f(z)

    init_carry = (z_init, f(z_init))
    _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

@jax.jit
def bayes_ridge_update(params, lambda_, t, Phi, prior_params):
    alpha_old, beta_old, _ = params
    a, b = prior_params
    
    n_samples = Phi.shape[0]
    n_terms = Phi.shape[1]
    
    gamma_ = jnp.sum((beta_old * lambda_) / (alpha_old + beta_old * lambda_))
    S = jnp.linalg.inv(beta_old * Phi.T @ Phi + alpha_old * jnp.eye(n_terms)) # Change to QR?
    mn = beta_old * S @ Phi.T @ t
    
    alpha = gamma_ / jnp.sum(mn**2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((t - Phi @ mn)**2) + 2 * b)
    
    return alpha, beta, mn

@jax.jit
def bayesian_ridge(params, inputs, tol=1e-3):
    # Unpacking inputs
    prior_params = params
    dt, theta = inputs
    
    # preparing some useful 
    init_vals = (1, 1 / jnp.var(dt), jnp.zeros((theta.shape[1], 1)))
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    eigvals = jnp.linalg.eigvalsh(theta_normed.T @ theta_normed) # do we really need both phi and gram?
    
    # Running
    update = partial(bayes_ridge_update, t=dt, Phi=theta_normed, lambda_=eigvals, prior_params=prior_params)
    alpha_, beta_, mn = fwd_solver(update, init_vals, tol=tol)
    mn = mn / jnp.linalg.norm(theta, axis=0)[:, None] # dimensionalizing again
    return alpha_, beta_, mn

In [52]:
n_samples = dt.shape[0]
prior_params = ((n_samples/2, 1/(n_samples/2 * 1e-4)))

In [53]:
%%timeit
bayesian_ridge(prior_params, (dt, theta))

1.8 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
bayesian_ridge(prior_params, (dt, theta))

(DeviceArray(0.1628958, dtype=float32),
 DeviceArray(199.29915, dtype=float32),
 DeviceArray([[ 4.7251882e-04],
              [-1.8976321e-02],
              [ 9.8508134e-02],
              [ 4.7280577e-05],
              [-4.0016812e-03],
              [-9.0126008e-01],
              [ 5.3995108e-04],
              [ 6.2670169e-04],
              [ 4.3693967e-03],
              [-9.1061093e-02],
              [ 7.6659984e-04],
              [-5.9627282e-04]], dtype=float32))

In [98]:
def fwd_solver(f, z_init, tol=1e-3):
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z_prev[0] - z[0]) > tol

    def body_fun(carry):
        _, z = carry
        return z, f(z)

    init_carry = (z_init, f(z_init))
    _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

def bayes_ridge_update(params, t, Phi, prior_params):
    Phi_normed = Phi / jnp.linalg.norm(theta, axis=0)
    eigvals = jnp.linalg.eigvalsh(Phi_normed.T @ Phi_normed) # do we really need both phi and gram?
    
    alpha_old, beta_old = params
    a, b = prior_params
    
    n_samples = Phi_normed.shape[0]
    n_terms = Phi_normed.shape[1]
    
    gamma_ = jnp.sum((beta_old * eigvals) / (alpha_old + beta_old * eigvals))
    S = jnp.linalg.inv(beta_old * Phi_normed.T @ Phi_normed + alpha_old * jnp.eye(n_terms)) # Change to QR?
    mn = beta_old * S @ Phi_normed.T @ t
    
    alpha = gamma_ / jnp.sum(mn**2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((t - Phi_normed @ mn)**2) + 2 * b)
    
    return jnp.ones((1, )) * alpha,  jnp.ones((1, )) * beta

@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def bayesian_ridge(prior_params, inputs):
    # Unpacking inputs
    dt, theta = inputs
    
    # preparing some useful 
    init_vals = (jnp.ones((1, )), jnp.ones((1, )) / jnp.var(dt))
    
    # Running
    update = partial(bayes_ridge_update, t=dt, Phi=theta, prior_params=prior_params)
    z_star = fwd_solver(update, init_vals, tol=1e-3)
    return z_star

def bayesian_ridge_fwd(prior_params, inputs):
    # Unpacking inputs
    dt, theta = inputs
    
    # preparing some useful 
    init_vals = (jnp.ones((1, )), jnp.ones((1, )) / jnp.var(dt))
    
    # Running
    update = partial(bayes_ridge_update, t=dt, Phi=theta, prior_params=prior_params)
    z_star = fwd_solver(update, init_vals, tol=1e-3)
    return z_star, (dt, theta, z_star)

def bayesian_ridge_bwd(prior_params, res, z_star_bar):
    dt, theta, z_star = res
    _, vjp_a = jax.vjp(lambda dt, Phi: bayes_ridge_update(params=z_star, t=dt, Phi=Phi, prior_params=prior_params), dt, theta)
    _, vjp_z = jax.vjp(lambda z: bayes_ridge_update(params=z, t=dt, Phi=theta, prior_params=prior_params), z_star)
    return vjp_a(fwd_solver(lambda u: vjp_z(u)[0] + z_star_bar,
                      z_init=(jnp.zeros((1,)), jnp.zeros((1, )))))

bayesian_ridge.defvjp(bayesian_ridge_fwd, bayesian_ridge_bwd)

In [99]:
bayesian_ridge_jit = jax.jit(bayesian_ridge, static_argnums=0)

n_samples = dt.shape[0]
prior_params = ((jnp.ones((1, )) * n_samples/2, jnp.ones((1, )) /(n_samples/2 * 1e-4)))
inputs = (dt, theta)

In [100]:
prior_params

(DeviceArray([1000.], dtype=float32), DeviceArray([10.], dtype=float32))

In [101]:
%%timeit
bayesian_ridge_jit(prior_params, inputs)

2.6 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [102]:
bayesian_ridge_jit(prior_params, (dt, theta))

(DeviceArray([0.1629698], dtype=float32),
 DeviceArray([199.29936], dtype=float32))

In [103]:
f = lambda inputs: bayesian_ridge_jit(prior_params, inputs)

In [104]:
f(inputs)

(DeviceArray([0.1629698], dtype=float32),
 DeviceArray([199.29936], dtype=float32))

In [105]:
jax.vjp(f, inputs)[1]((jnp.ones((1, )), 1.0))

TypeError: Tree structure of cotangent input PyTreeDef(tuple, [*,*,*,*]), does not match structure of primal output PyTreeDef(tuple, [*,*])