We know how to recreate the problem and ensure it's not the jit - now let's try and figure out why.

In [2]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver

from flax.core import unfreeze
from flax.traverse_util import flatten_dict
%load_ext autoreload
%autoreload 2

In [3]:
def SBL(
    X,
    y,
    prior_init=None,
    alpha_prior=(1e-6, 1e-6),
    beta_prior=(1e-6, 1e-6),
    tol=1e-3,
    max_iter=300,
    non_diff=True
):
    n_samples, n_features = X.shape
    norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)
    if prior_init is None:
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), (1.0 / (jnp.var(y) + 1e-7))[jnp.newaxis]], axis=0
        )
    # adding zeros to z for coeffs
    gram = jnp.dot(X.T, X)
    XT_y = jnp.dot(X.T, y)

    prior_params, metrics = fixed_point_solver(
        update,
        (X, y, gram, XT_y, alpha_prior, beta_prior),
        prior_init,
        norm_weight,
        tol=tol,
        max_iter=max_iter,
    )
    
    if non_diff:
        prior = jax.lax.stop_gradient(prior_params) # no it doesnt backprop through the prior
    else:
        prior = prior_params
    loss, mn = evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior)

    return loss, mn, prior, metrics

def update_sigma(gram, alpha, beta):
    sigma_inv = jnp.diag(alpha) + beta * gram
    L_inv = jnp.linalg.pinv(jnp.linalg.cholesky(sigma_inv))
    sigma_ = jnp.dot(L_inv.T, L_inv)
    return sigma_


def update_coeff(XT_y, beta, sigma_):
    coef_ = beta * jnp.linalg.multi_dot([sigma_, XT_y])
    return coef_


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Update alpha and lambda
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)
    gamma_ = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha = (gamma_ + 2.0 * alpha_prior[0]) / (
        (coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - gamma_.sum() + 2.0 * beta_prior[0]) / (
        rmse_ + 2.0 * beta_prior[1]
    )

    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)


def evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]

    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)

    score = jnp.sum(alpha_prior[0] * jnp.log(alpha) - alpha_prior[1] * alpha)
    score += beta_prior[0] * jnp.log(beta) - beta_prior[1] * beta
    score += 0.5 * (
        jnp.linalg.slogdet(sigma)[1]
        + n_samples * jnp.log(beta)
        + jnp.sum(jnp.log(alpha))
    )
    score -= 0.5 * (beta * rmse_ + jnp.sum(alpha * coeffs.squeeze() ** 2))

    return score.squeeze(), coeffs


In [4]:
def loss_fn_SBL(params, state, model, X, y, warm_restart=True, non_diff=True):
    model_state, loss_state = state
    variables = {"params": params, **model_state}
    (prediction, dt, theta, coeffs), updated_model_state = model.apply(
        variables, X, mutable=list(model_state.keys())
    )

    n_samples, n_features = theta.shape
    prior_params_mse = (0.0, 0.0)

    # MSE stuff
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)

    # Regression stuff
    # we dont want the gradient
    hyper_prior_params = (
        n_samples / 2,
        n_samples / (2 * jax.lax.stop_gradient(tau)),
    )
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)

    if (loss_state["prior_init"] is None) or (warm_restart is False):
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
        )
    else:
        prior_init = loss_state["prior_init"]

    p_reg, mn, prior, fwd_metric = SBL(
        theta_normed,
        dt,
        prior_init=prior_init,
        beta_prior=hyper_prior_params,
        tol=1e-3,
        max_iter=300,
        non_diff=non_diff
    )

    Reg = jnp.mean((dt - theta_normed @ mn) ** 2)

    updated_loss_state = {"prior_init": prior}
    loss = -(p_mse + p_reg)
    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": Reg,
        "bayes_coeffs": mn,
        "coeffs": coeffs,
        "alpha": prior[:-1],
        "beta": prior[-1],
        "tau": tau,
        "its": fwd_metric[0],
        "gap": fwd_metric[1],
    }

    return (
        loss,
        ((updated_model_state, updated_loss_state), metrics, (prediction, dt, theta, mn)),
    )


In [5]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [5]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))

To pinpoint the problem,let's look at the params after the first epoch (without jit)

In [6]:
with jax.disable_jit():
    (new_optimizer, new_state), metrics, output = update_fn(optimizer, state)

In [7]:
state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': None})

In [8]:
new_state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': DeviceArray([3.8261806e+01, 1.2139067e+03, 1.5917789e+00, 1.3994301e+03,
               3.5256889e-01, 1.6201604e+03, 1.0433588e+03, 1.2462245e+03,
               8.3063586e+02, 1.3484807e+03, 8.2134668e+02, 1.3358071e+03,
               4.2957764e+00], dtype=float32)})

Wait so I actually overwrote the loss_state variable; let's fix this

In [9]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))

In [10]:
with jax.disable_jit():
    (new_optimizer, new_state), metrics, output = update_fn(optimizer, state)

In [11]:
state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': None})

In [12]:
new_state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': DeviceArray([3.8261806e+01, 1.2139067e+03, 1.5917789e+00, 1.3994301e+03,
               3.5256889e-01, 1.6201604e+03, 1.0433588e+03, 1.2462245e+03,
               8.3063586e+02, 1.3484807e+03, 8.2134668e+02, 1.3358071e+03,
               4.2957764e+00], dtype=float32)})

Now it's correct. The prior doesn't seem unnecessarily big; let's check the rest

In [13]:
new_optimizer.target

FrozenDict({
    MLP_0: {
        Dense_0: {
            bias: DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                         nan, nan, nan, nan, nan, nan], dtype=float32),
            kernel: DeviceArray([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                          nan, nan, nan, nan, nan, nan],
                         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                          nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                          nan, nan, nan, nan, nan, nan]], dtype=float32),
        },
        Dense_1: {
            bias: DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                  

Well that explains it - everything is nan - weird. How does the output look?

In [14]:
print(f"Nans: {[jnp.any(jnp.isnan(i)) for i in output]}")
print(f"Infs: {[jnp.any(jnp.isinf(i)) for i in output]}")
print(f"min: {[jnp.min(i) for i in output]}")
print(f"max: {[jnp.max(i) for i in output]}")

Nans: [DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool)]
Infs: [DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool)]
min: [DeviceArray(-0.08321948, dtype=float32), DeviceArray(-0.10772296, dtype=float32), DeviceArray(-0.8179425, dtype=float32), DeviceArray(-0.00060113, dtype=float32)]
max: [DeviceArray(0.1909909, dtype=float32), DeviceArray(0.16929603, dtype=float32), DeviceArray(2.387917, dtype=float32), DeviceArray(1.606274, dtype=float32)]


That seems fine, so the nans have to be caused by the SBL backprop - let's verify by running the non-diff again:

In [15]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True, True))

In [16]:
with jax.disable_jit():
    (new_optimizer, new_state), metrics, output = update_fn(optimizer, state)

In [17]:
state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': None})

In [18]:
new_state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': DeviceArray([3.8261806e+01, 1.2139067e+03, 1.5917789e+00, 1.3994301e+03,
               3.5256889e-01, 1.6201604e+03, 1.0433588e+03, 1.2462245e+03,
               8.3063586e+02, 1.3484807e+03, 8.2134668e+02, 1.3358071e+03,
               4.2957764e+00], dtype=float32)})

Now it's correct. The prior doesn't seem unnecessarily big; let's check the rest

In [19]:
new_optimizer.target

FrozenDict({
    MLP_0: {
        Dense_0: {
            bias: DeviceArray([-0.00199999,  0.00199999,  0.002     , -0.00199999,
                         -0.00199999,  0.00199999, -0.00199999, -0.00199999,
                         -0.00199999,  0.00199999, -0.002     ,  0.00199999,
                          0.002     , -0.002     , -0.00199999,  0.002     ,
                          0.00199999,  0.00199999,  0.00199999,  0.002     ,
                         -0.002     , -0.00199999, -0.00199999,  0.002     ,
                          0.00199999, -0.00199999,  0.00199999,  0.00199999,
                         -0.00199999,  0.00199999], dtype=float32),
            kernel: DeviceArray([[-0.61155164, -1.1547233 , -0.33419877,  0.63251287,
                           0.40870082,  0.9491058 ,  0.43981987,  0.1844259 ,
                           0.02297306, -0.3432148 ,  0.1914811 , -0.9863748 ,
                           0.5854419 , -0.65472823,  0.03556921, -1.392958  ,
                      

Yeah that's all fine, as expected. Let's see if we can recreate this; it seems the forward pass is fine.

In [20]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)
state, params = variables.pop("params")
state = (state, {"prior_init": None})  # adding prior to state

In [21]:
# Forward pass
model_state, loss_state = state
variables = {"params": params, **model_state}
(prediction, dt, theta, coeffs), updated_model_state = model.apply(
    variables, X, mutable=list(model_state.keys())
)


In [22]:
output = (prediction, dt, theta, coeffs)
print(f"Nans: {[jnp.any(jnp.isnan(i)) for i in output]}")
print(f"Infs: {[jnp.any(jnp.isinf(i)) for i in output]}")
print(f"min: {[jnp.min(i) for i in output]}")
print(f"max: {[jnp.max(i) for i in output]}")

Nans: [DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool)]
Infs: [DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool), DeviceArray(False, dtype=bool)]
min: [DeviceArray(-0.08321948, dtype=float32), DeviceArray(-0.10772296, dtype=float32), DeviceArray(-0.8179425, dtype=float32), DeviceArray(-20.45381, dtype=float32)]
max: [DeviceArray(0.1909909, dtype=float32), DeviceArray(0.16929603, dtype=float32), DeviceArray(2.387917, dtype=float32), DeviceArray(2.1001604, dtype=float32)]


as expected, thats fine.

In [23]:
# MSE stuff
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)
p_mse, MSE = normal_LL(prediction, y, tau)

In [24]:
# Regression stuff
# we dont want the gradient

n_samples, n_features = theta.shape
hyper_prior_params = (
    n_samples / 2,
    n_samples / (2 * jax.lax.stop_gradient(tau)),
)

theta_normed = theta / jnp.linalg.norm(theta, axis=0)

prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
        )
    
p_reg, mn, prior, fwd_metric = SBL(
    theta_normed,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=300,
    non_diff=False
)

Let's look at the results:

In [25]:
print(p_reg, mn, prior)

455.7085 [[ 4.8160817e-02]
 [-1.9316360e-05]
 [ 6.5489900e-01]
 [-8.6727006e-05]
 [ 1.6062734e+00]
 [ 5.8015270e-05]
 [-6.0136005e-04]
 [ 7.5137132e-04]
 [-3.0721983e-04]
 [ 4.9039809e-04]
 [ 6.4977107e-04]
 [-4.6146798e-04]] [3.8261505e+01 1.2122681e+03 1.5917799e+00 1.3987809e+03 3.5256892e-01
 1.6174141e+03 1.0430642e+03 1.2461332e+03 8.3042767e+02 1.3485659e+03
 8.2173694e+02 1.3355662e+03 4.2957764e+00]


In [26]:
print(fwd_metric)

(DeviceArray(300, dtype=int32), DeviceArray(704.2981, dtype=float32))


That a fairly big gap, but that shouldn't be the issue

In [27]:
Reg = jnp.mean((dt - theta_normed @ mn) ** 2)
loss = -(p_mse + p_reg)

In [28]:
print(loss)

577.9458


Now as the code works when we use the non-diff version of alpha; the problem should manifest itself if we try to calculate the grad of the prior. Lets try and calculate the grad of the evidence first

In [29]:
prior

DeviceArray([3.8261505e+01, 1.2122681e+03, 1.5917799e+00, 1.3987809e+03,
             3.5256892e-01, 1.6174141e+03, 1.0430642e+03, 1.2461332e+03,
             8.3042767e+02, 1.3485659e+03, 8.2173694e+02, 1.3355662e+03,
             4.2957764e+00], dtype=float32)

In [30]:
grad_evidence = jax.grad(lambda X: SBL(
    X,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=300,
    non_diff=False
)[0])
grad_evidence(theta_normed)

DeviceArray([[nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             ...,
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

Right. Now let's do the same without the alpha grad:

In [31]:
grad_evidence = jax.grad(lambda X: SBL(
    X,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=300,
    non_diff=True
)[0])
grad_evidence(theta_normed)

DeviceArray([[-1.2440996e-03,  4.1718004e-06,  2.5634980e-02, ...,
               1.9654544e-05,  1.8911585e-05, -2.0199899e-05],
             [-1.3899140e-03,  4.3380123e-06,  2.3853820e-02, ...,
               1.8198010e-05,  1.7685046e-05, -1.8724470e-05],
             [-1.5445322e-03,  4.5053612e-06,  2.1956662e-02, ...,
               1.6653010e-05,  1.6358415e-05, -1.7162565e-05],
             ...,
             [ 5.0728256e-04,  4.9672240e-06,  3.1806353e-02, ...,
               3.5711520e-05, -2.1013140e-05, -4.3645468e-05],
             [ 4.1561108e-04,  4.9216474e-06,  3.0796958e-02, ...,
               3.4790366e-05, -2.1361724e-05, -4.2672349e-05],
             [ 3.2548304e-04,  4.8696011e-06,  2.9805612e-02, ...,
               3.3884531e-05, -2.1701526e-05, -4.1715357e-05]],            dtype=float32)

As expected, that works. Now let's check the alpha

In [38]:
grad_alpha = jax.grad(lambda X: SBL(
    X,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=5000,
    non_diff=False
)[2][0])
grad_alpha(theta_normed)

DeviceArray([[ 1.48259335e+01,  1.13631815e-01,  4.15236244e+01, ...,
               2.46541947e-02,  3.99421901e-02, -2.55021565e-02],
             [ 1.62474537e+01,  1.18221581e-01,  4.12807884e+01, ...,
               2.53305510e-02,  4.05450463e-02, -2.55433004e-02],
             [ 1.77543659e+01,  1.22841269e-01,  4.10161934e+01, ...,
               2.60445736e-02,  4.11799699e-02, -2.55829357e-02],
             ...,
             [-3.03727722e+00,  1.36570469e-01,  3.10729713e+01, ...,
               9.37876478e-03,  2.39835773e-02, -1.68121438e-02],
             [-2.13769913e+00,  1.35279939e-01,  3.10193806e+01, ...,
               9.86818224e-03,  2.44306866e-02, -1.69030670e-02],
             [-1.25323868e+00,  1.33809865e-01,  3.09675369e+01, ...,
               1.03502646e-02,  2.48709843e-02, -1.69931948e-02]],            dtype=float32)

In [36]:
grad_alpha = jax.grad(lambda X: SBL(
    X,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=300,
    non_diff=False
)[2][0])
grad_alpha(theta_normed)

DeviceArray([[nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             ...,
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

Aha so I get Nan's when i have a to low number of iterations!

In [34]:
grad_alpha = jax.grad(lambda X: SBL(
    X,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=300,
    non_diff=True
)[2][0])
grad_alpha(theta_normed)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

So we surely found the cause; is the backprop in the alpha. Let's do that in a separate notebook.

So if we use much more iterations for the SBL we dont get nans? let's check by running it:

In [6]:
def loss_fn_SBL(params, state, model, X, y, warm_restart=True, non_diff=True):
    model_state, loss_state = state
    variables = {"params": params, **model_state}
    (prediction, dt, theta, coeffs), updated_model_state = model.apply(
        variables, X, mutable=list(model_state.keys())
    )

    n_samples, n_features = theta.shape
    prior_params_mse = (0.0, 0.0)

    # MSE stuff
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)

    # Regression stuff
    # we dont want the gradient
    hyper_prior_params = (
        n_samples / 2,
        n_samples / (2 * jax.lax.stop_gradient(tau)),
    )
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)

    if (loss_state["prior_init"] is None) or (warm_restart is False):
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
        )
    else:
        prior_init = loss_state["prior_init"]

    p_reg, mn, prior, fwd_metric = SBL(
        theta_normed,
        dt,
        prior_init=prior_init,
        beta_prior=hyper_prior_params,
        tol=1e-3,
        max_iter=5000,
        non_diff=non_diff
    )

    Reg = jnp.mean((dt - theta_normed @ mn) ** 2)

    updated_loss_state = {"prior_init": prior}
    loss = -(p_mse + p_reg)
    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": Reg,
        "bayes_coeffs": mn,
        "coeffs": coeffs,
        "alpha": prior[:-1],
        "beta": prior[-1],
        "tau": tau,
        "its": fwd_metric[0],
        "gap": fwd_metric[1],
    }

    return (
        loss,
        ((updated_model_state, updated_loss_state), metrics, (prediction, dt, theta, mn)),
    )


In [64]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))

In [None]:
for i in jnp.arange(5000):
    (optimizer, state), metrics, output = update_fn(optimizer, state)
    if i % 100 == 0:
        has_nan = jnp.any(jnp.array([jnp.any(jnp.isnan(item)) for _, item in flatten_dict(unfreeze(optimizer.target)).items()]))
        mse = metrics["mse"]
        print(f"Done with {i}, Nans: {has_nan}, MSE: {mse}")

Done with 0, Nans: False, MSE: 0.4627499580383301
Done with 100, Nans: False, MSE: 1.0874793529510498
Done with 200, Nans: False, MSE: 3.0515739917755127


In [60]:
output[3]

Buffer([[-7.0177660e+00],
        [ 3.8864780e+01],
        [-1.8419664e+01],
        [ 1.1504240e-03],
        [-2.4564800e+00],
        [ 4.5422382e+01],
        [ 8.7160920e-04],
        [-2.5117257e+00],
        [-8.7517255e-04],
        [ 4.3179214e-01],
        [-1.4310470e-03],
        [ 3.1985741e-03]], dtype=float32)

In [7]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True, True))

In [None]:
for i in jnp.arange(5000):
    (optimizer, state), metrics, output = update_fn(optimizer, state)
    if i % 100 == 0:
        has_nan = jnp.any(jnp.array([jnp.any(jnp.isnan(item)) for _, item in flatten_dict(unfreeze(optimizer.target)).items()]))
        mse = metrics["mse"]
        print(f"Done with {i}, Nans: {has_nan}, MSE: {mse}")

Done with 0, Nans: False, MSE: 0.4627499580383301
Done with 100, Nans: False, MSE: 0.22842560708522797
Done with 200, Nans: False, MSE: 0.14117945730686188
Done with 300, Nans: False, MSE: 0.06718077510595322
Done with 400, Nans: False, MSE: 0.048002939671278
Done with 500, Nans: False, MSE: 0.031150447204709053
Done with 600, Nans: False, MSE: 0.024588650092482567
Done with 700, Nans: False, MSE: 0.011292420327663422
Done with 800, Nans: False, MSE: 0.003878676798194647
Done with 900, Nans: False, MSE: 0.0035111813340336084
Done with 1000, Nans: False, MSE: 0.0034217783249914646
Done with 1100, Nans: False, MSE: 0.003423253307119012
Done with 1200, Nans: False, MSE: 0.0034096508752554655
Done with 1300, Nans: False, MSE: 0.003411792917177081
Done with 1400, Nans: False, MSE: 0.003428456373512745
Done with 1500, Nans: False, MSE: 0.0034271518234163523
Done with 1600, Nans: False, MSE: 0.0033987085334956646
Done with 1700, Nans: False, MSE: 0.0034088718239217997
Done with 1800, Nans: Fa