We know we don't need the gradient, (or at least for now), so let's make the code for the SBL a lot better by correctly doing 1) the cholesky decomp and 2) implementing the evidence convegrence criterium. In this notebook we focus on point 1).

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver, fwd_solver, fwd_solver_simple

from flax.core import unfreeze
from flax.traverse_util import flatten_dict
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()


%load_ext autoreload
%autoreload 2

%config InlineBackend.figure_format = 'svg'

# Making data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

state, params = variables.pop("params")

prediction, dt, theta, coeffs = model.apply({"params": params, **state}, X)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))

In [6]:
n_samples, n_features = theta_normed.shape
norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)
prior_init = jnp.concatenate([jnp.ones((n_features,)), (1.0 / (jnp.var(dt) + 1e-7))[jnp.newaxis]], axis=0)
gram = jnp.dot(theta_normed.T, theta_normed)
XT_y = jnp.dot(theta_normed.T, dt)

tol = 1e-3
max_iter = 1000 # low to keep it manageable

# Baseline

In [7]:
def update_sigma(gram, alpha, beta):
    sigma_inv = jnp.diag(alpha) + beta * gram
    L_inv = jnp.linalg.pinv(jnp.linalg.cholesky(sigma_inv))
    sigma_ = jnp.dot(L_inv.T, L_inv)
    return sigma_


def update_coeff(XT_y, beta, sigma_):
    coef_ = beta * jnp.linalg.multi_dot([sigma_, XT_y])
    return coef_


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Update alpha and lambda
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)
    gamma_ = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha = (gamma_ + 2.0 * alpha_prior[0]) / (
        (coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - gamma_.sum() + 2.0 * beta_prior[0]) / (
        rmse_ + 2.0 * beta_prior[1]
    )

    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

def evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]

    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)

    score = jnp.sum(alpha_prior[0] * jnp.log(alpha) - alpha_prior[1] * alpha)
    score += beta_prior[0] * jnp.log(beta) - beta_prior[1] * beta
    score += 0.5 * (
        jnp.linalg.slogdet(sigma)[1]
        + n_samples * jnp.log(beta)
        + jnp.sum(jnp.log(alpha))
    )
    score -= 0.5 * (beta * rmse_ + jnp.sum(alpha * coeffs.squeeze() ** 2))

    return score.squeeze(), coeffs

In [8]:
# once for jit
prior_params_baseline, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [9]:
%%time
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

CPU times: user 284 ms, sys: 3.32 ms, total: 288 ms
Wall time: 286 ms


In [10]:
prior_params_baseline

DeviceArray([3.8261593e+01, 1.4216833e+03, 1.5917912e+00, 1.4647550e+03,
             3.5256863e-01, 1.4051538e+03, 1.0434841e+03, 1.2460419e+03,
             8.3008948e+02, 1.3492664e+03, 8.2153131e+02, 1.3355718e+03,
             4.2957764e+00], dtype=float32)

# V1 - switching to cholesky

In [11]:
from jax.numpy.linalg import cholesky
from jax.scipy.linalg import solve_triangular

In [12]:
def update_sigma(gram, alpha, beta):
    sigma_inv = jnp.diag(alpha) + beta * gram
    L = cholesky(sigma_inv)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    
    return jnp.dot(R.T, R)


def update_coeff(XT_y, beta, sigma_):
    coef_ = beta * jnp.linalg.multi_dot([sigma_, XT_y])
    return coef_


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Update alpha and lambda
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)
    gamma_ = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha = (gamma_ + 2.0 * alpha_prior[0]) / (
        (coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - gamma_.sum() + 2.0 * beta_prior[0]) / (
        rmse_ + 2.0 * beta_prior[1]
    )

    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

In [13]:
# once for jit
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [14]:
%%time
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

CPU times: user 58.9 ms, sys: 0 ns, total: 58.9 ms
Wall time: 58 ms


In [15]:
jnp.abs(prior_params_baseline - prior_params) / prior_params_baseline

DeviceArray([6.4805281e-06, 3.1433667e-03, 4.4934018e-07, 3.0143492e-04,
             1.5215245e-06, 5.6858559e-04, 2.7151845e-04, 1.2637673e-05,
             4.5822901e-04, 8.2238701e-05, 2.4153102e-04, 4.4877052e-05,
             0.0000000e+00], dtype=float32)

Perfect - now in the next step let's merge sigma and mu and do it all efficiently.

# V1 - everything by cholesky

In [13]:
from jax.scipy.linalg import solve_triangular
from jax.numpy.linalg import cholesky

In [17]:
def update_posterior(gram, XT_y, alpha, beta):
    L = cholesky(jnp.diag(alpha) + beta * gram)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)
    
    return mean, sigma

def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    coeffs, sigma = update_posterior(gram, XT_y, alpha, beta)

    # Update alpha and lambda
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)
    gamma_ = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha = (gamma_ + 2.0 * alpha_prior[0]) / (
        (coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - gamma_.sum() + 2.0 * beta_prior[0]) / (
        rmse_ + 2.0 * beta_prior[1]
    )

    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

In [18]:
# once for jit
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [19]:
%%time
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

CPU times: user 58.5 ms, sys: 185 µs, total: 58.7 ms
Wall time: 58 ms


In [20]:
jnp.abs(prior_params_baseline - prior_params) / prior_params_baseline

DeviceArray([6.4805281e-06, 3.1433667e-03, 4.4934018e-07, 3.0143492e-04,
             1.5215245e-06, 5.6858559e-04, 2.7151845e-04, 1.2637673e-05,
             4.5822901e-04, 8.2238701e-05, 2.4153102e-04, 4.4877052e-05,
             0.0000000e+00], dtype=float32)

In [21]:
prior_params

DeviceArray([3.8261841e+01, 1.4261522e+03, 1.5917919e+00, 1.4651965e+03,
             3.5256809e-01, 1.4059528e+03, 1.0432008e+03, 1.2460576e+03,
             8.3046985e+02, 1.3493773e+03, 8.2133289e+02, 1.3355118e+03,
             4.2957764e+00], dtype=float32)

# V2 - Putting alpha and beta in separate function

In [22]:
from jax.scipy.linalg import solve_triangular
from jax.numpy.linalg import cholesky

In [23]:
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = cholesky(jnp.diag(alpha) + beta * gram)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)
    
    return mean, sigma

def update_prior(X, y, posterior, prior, hyper_prior):
    mean, covariance = posterior
    alpha, beta = prior
    alpha_prior, beta_prior = hyper_prior
    
  
    rmse = jnp.sum((y - jnp.dot(X, mean)) ** 2)
    gamma = 1.0 - alpha * jnp.diag(covariance)

    # Update alpha and beta
    alpha = (gamma + 2.0 * alpha_prior[0]) / (
        (mean.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - jnp.sum(gamma) + 2.0 * beta_prior[0]) / (
        rmse + 2.0 * beta_prior[1]
    )

    return jnp.minimum(1e7, alpha), jnp.minimum(1e7, beta)


def update(prior, X, y, gram, XT_y, hyper_prior):
    n_samples, n_features = X.shape
    prior = (prior[:-1], prior[-1])
    posterior = update_posterior(gram, XT_y, prior)
    alpha, beta = update_prior(X, y, posterior, prior, hyper_prior)
  
    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

In [24]:
# once for jit
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, (alpha_prior, beta_prior)),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [25]:
%%time
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, (alpha_prior, beta_prior)),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

CPU times: user 57.4 ms, sys: 74 µs, total: 57.5 ms
Wall time: 56.6 ms


In [26]:
jnp.abs(prior_params_baseline - prior_params) / prior_params_baseline

DeviceArray([6.4805281e-06, 3.1433667e-03, 4.4934018e-07, 3.0143492e-04,
             1.5215245e-06, 5.6858559e-04, 2.7151845e-04, 1.2637673e-05,
             4.5822901e-04, 8.2238701e-05, 2.4153102e-04, 4.4877052e-05,
             0.0000000e+00], dtype=float32)

In [27]:
prior_params

DeviceArray([3.8261841e+01, 1.4261522e+03, 1.5917919e+00, 1.4651965e+03,
             3.5256809e-01, 1.4059528e+03, 1.0432008e+03, 1.2460576e+03,
             8.3046985e+02, 1.3493773e+03, 8.2133289e+02, 1.3355118e+03,
             4.2957764e+00], dtype=float32)

In [28]:
metrics

(DeviceArray(1000, dtype=int32), DeviceArray(75.63577, dtype=float32))

# v3 putting in evidence

As a convergence criterium I want to use the mean gradient of the loss - luckily, there's an analytical expression in tipping. Let's implement it.

In [7]:
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = cholesky(jnp.diag(alpha) + beta * gram)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)
    
    return mean, sigma

def update_prior(X, y, posterior, prior, hyper_prior):
    mean, covariance = posterior
    alpha, beta = prior
    alpha_prior, beta_prior = hyper_prior
    
  
    rmse = jnp.sum((y - jnp.dot(X, mean)) ** 2)
    gamma = 1.0 - alpha * jnp.diag(covariance)

    # Update alpha and beta
    alpha = (gamma + 2.0 * alpha_prior[0]) / (
        (mean.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - jnp.sum(gamma) + 2.0 * beta_prior[0]) / (
        rmse + 2.0 * beta_prior[1]
    )
    
    # Calculating dL/da
    dLda = 1/2 * (1 / alpha - (mean.squeeze()**2 + jnp.diag(covariance))) + alpha_prior[0] / alpha - alpha_prior[1]
    return jnp.minimum(1e7, alpha), jnp.minimum(1e7, beta), jnp.mean(jnp.abs(dLda))


def update(prior, X, y, gram, XT_y, hyper_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-2], prior[-2]
    posterior = update_posterior(gram, XT_y, (alpha, beta))
    alpha, beta, loss_grad = update_prior(X, y, posterior, (alpha, beta), hyper_prior)
  
    return jnp.concatenate([alpha, beta[jnp.newaxis], loss_grad[jnp.newaxis]], axis=0)

def evidence(X, y, gram, XT_y, prior, hyper_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    alpha_prior, beta_prior = hyper_prior
    
    mean, covariance = update_posterior(gram, XT_y, (alpha, beta))
    rmse = jnp.sum((y - jnp.dot(X, mean)) ** 2)

    score = jnp.sum(alpha_prior[0] * jnp.log(alpha) - alpha_prior[1] * alpha)
    score += beta_prior[0] * jnp.log(beta) - beta_prior[1] * beta
    score += 0.5 * (
        jnp.linalg.slogdet(covariance)[1]
        + n_samples * jnp.log(beta)
        + jnp.sum(jnp.log(alpha))
    )
    score -= 0.5 * (beta * rmse + jnp.sum(alpha * mean.squeeze() ** 2))

    return score.squeeze(), coeffs

In [56]:
norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((2,))), axis=0)
prior_init = jnp.concatenate([jnp.ones((n_features,)), (1.0 / (jnp.var(dt) + 1e-7))[jnp.newaxis], jnp.ones((1, ))], axis=0)

In [57]:
# once for jit
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, (alpha_prior, beta_prior)),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [58]:
%%time
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, (alpha_prior, beta_prior)),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

CPU times: user 57.4 ms, sys: 68 µs, total: 57.4 ms
Wall time: 56.5 ms


In [59]:
prior_params[-1]

DeviceArray(2.2866866e-06, dtype=float32)

That seems higher then what I observed when taking the gradient of the evidence - let's check:

In [60]:
dprior = jax.grad(lambda prior: evidence(theta_normed, dt, gram, XT_y, prior, (alpha_prior, beta_prior))[0])(prior_params[:-1])

In [61]:
jnp.mean(jnp.abs(dprior[:-2]))

DeviceArray(7.2724504e-08, dtype=float32)

It def is but with these numbers I think there's some numerical issue. Now let's build a function using a custom cond_fun

# Full SBL function

In [8]:
from jax import jit
from functools import partial
from jax import lax

In [138]:
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = cholesky(jnp.diag(alpha) + beta * gram)
    R = solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)
    
    return mean, sigma

def update_prior(X, y, posterior, prior, hyper_prior):
    mean, covariance = posterior
    alpha, beta = prior
    alpha_prior, beta_prior = hyper_prior
    
    rmse = jnp.sum((y - jnp.dot(X, mean)) ** 2)
    gamma = 1.0 - alpha * jnp.diag(covariance)

    # Update alpha and beta
    alpha = (gamma + 2.0 * alpha_prior[0]) / (
        (mean.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - jnp.sum(gamma) + 2.0 * beta_prior[0]) / (
        rmse + 2.0 * beta_prior[1]
    )
    
    # Calculating dL/da
    dLda = 1 / alpha * (1/2 * (1 - alpha * (mean.squeeze()**2 + jnp.diag(covariance))) + alpha_prior[0] - alpha * alpha_prior[1])
    return jnp.minimum(1e7, alpha), jnp.minimum(1e7, beta), jnp.mean(jnp.abs(dLda))


def update(prior, X, y, gram, XT_y, hyper_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-2], prior[-2]
    posterior = update_posterior(gram, XT_y, (alpha, beta))
    alpha, beta, loss_grad = update_prior(X, y, posterior, (alpha, beta), hyper_prior)
  
    return jnp.concatenate([alpha, beta[jnp.newaxis], loss_grad[jnp.newaxis]], axis=0)

def evidence(X, y, gram, XT_y, prior, hyper_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    alpha_prior, beta_prior = hyper_prior
    
    mean, covariance = update_posterior(gram, XT_y, (alpha, beta))
    rmse = jnp.sum((y - jnp.dot(X, mean)) ** 2)

    score = jnp.sum(alpha_prior[0] * jnp.log(alpha) - alpha_prior[1] * alpha)
    score += beta_prior[0] * jnp.log(beta) - beta_prior[1] * beta
    score += 0.5 * (
        jnp.linalg.slogdet(covariance)[1]
        + n_samples * jnp.log(beta)
        + jnp.sum(jnp.log(alpha))
    )
    score -= 0.5 * (beta * rmse + jnp.sum(alpha * mean.squeeze() ** 2))

    return score.squeeze(), mean

In [139]:
@partial(jit, static_argnums=(0, 2))
def fwd_solver_2(f, z_init, cond_fun, max_iter=300):
    # n_features calculates the norm over the first n_features of z.
    # Useful for when you're iterating over a but check your convergence on b
    # such as with SBL.
    def _cond_fun(carry):
        z_prev, z, iteration = carry
        return jax.lax.cond(iteration >= max_iter, 
                     lambda _: False, 
                     lambda args: cond_fun(*args), 
                     (z_prev, z))

    def body_fun(carry):
        _, z, iteration = carry
        return z, f(z), iteration + 1

    init_carry = (z_init, f(z_init), 0, )
    z_star, _, metrics = lax.while_loop(_cond_fun, body_fun, init_carry)
    return z_star, metrics

@partial(jit, static_argnums=(0, 3))
def fixed_point_solver_2(f, args, z_init, cond_fun, max_iter=300):
    z_star, metrics = fwd_solver_2(
        lambda z: f(z, *args), z_init, cond_fun, max_iter=max_iter,
    )
    return z_star, metrics

In [140]:
prior_init = jnp.concatenate([jnp.ones((n_features,)), (1.0 / (jnp.var(dt) + 1e-7))[jnp.newaxis], jnp.ones((1, ))], axis=0)

In [141]:
# once for jit
prior_params, metrics = fixed_point_solver_2(
    update,
    (theta_normed, dt, gram, XT_y, (alpha_prior, beta_prior)),
    prior_init,
    lambda z_prev, z: z[-1] > 1e-5,
    max_iter=max_iter,
)

This works well - now let's wrap it in a full SBL function and update the fwd solver.

In [142]:
def SBL(
    X,
    y,
    prior_init=None,
    hyper_prior=((1e-6, 1e-6), (1e-6, 1e-6)), 
    tol=1e-5,
    max_iter=1000,
    stop_prior_grad=True):
    
    n_samples, n_features = X.shape
    if prior_init is None:
        prior_init = jnp.ones((n_features + 2, ))
        prior_init = jax.ops.index_update(prior_init, n_features, 1 / (jnp.var(y) + 1e-6)) # setting initial noise value
 
    gram = jnp.dot(X.T, X)
    XT_y = jnp.dot(X.T, y)

    prior_params, metrics = fixed_point_solver_2(
        update,
        (X, y, gram, XT_y, hyper_prior),
        prior_init,
        lambda z_prev, z: z[-1] > tol,
        max_iter=max_iter,
    )
    
    if stop_prior_grad:
        prior = lax.stop_gradient(prior_params[:-1])
    else:
        prior = prior_params[:-1]
        
    loss, mn = evidence(X, y, gram, XT_y, prior, hyper_prior)
    return loss, mn, prior, metrics

In [143]:
%%time
SBL(theta_normed, dt, None, (alpha_prior, beta_prior), tol=1e-5, max_iter=1000)

CPU times: user 628 ms, sys: 10.4 ms, total: 638 ms
Wall time: 247 ms


(DeviceArray(455.7086, dtype=float32),
 DeviceArray([[ 4.8160609e-02],
              [-1.4313304e-05],
              [ 6.5489328e-01],
              [-8.3149920e-05],
              [ 1.6062721e+00],
              [ 7.1582544e-05],
              [-6.0125935e-04],
              [ 7.5151841e-04],
              [-3.0715278e-04],
              [ 4.9016980e-04],
              [ 6.5001904e-04],
              [-4.6155893e-04]], dtype=float32),
 DeviceArray([3.8261742e+01, 1.6299744e+03, 1.5917979e+00, 1.4587059e+03,
              3.5256785e-01, 1.3100173e+03, 1.0432620e+03, 1.2459415e+03,
              8.3059930e+02, 1.3490701e+03, 8.2140729e+02, 1.3352468e+03,
              4.2957759e+00], dtype=float32),
 DeviceArray(472, dtype=int32))

In [144]:
%%time
SBL(theta_normed, dt, None, (alpha_prior, beta_prior), tol=1e-5, max_iter=1000)

CPU times: user 628 ms, sys: 0 ns, total: 628 ms
Wall time: 247 ms


(DeviceArray(455.7086, dtype=float32),
 DeviceArray([[ 4.8160609e-02],
              [-1.4313304e-05],
              [ 6.5489328e-01],
              [-8.3149920e-05],
              [ 1.6062721e+00],
              [ 7.1582544e-05],
              [-6.0125935e-04],
              [ 7.5151841e-04],
              [-3.0715278e-04],
              [ 4.9016980e-04],
              [ 6.5001904e-04],
              [-4.6155893e-04]], dtype=float32),
 DeviceArray([3.8261742e+01, 1.6299744e+03, 1.5917979e+00, 1.4587059e+03,
              3.5256785e-01, 1.3100173e+03, 1.0432620e+03, 1.2459415e+03,
              8.3059930e+02, 1.3490701e+03, 8.2140729e+02, 1.3352468e+03,
              4.2957759e+00], dtype=float32),
 DeviceArray(472, dtype=int32))