For some reason, the custom backpropagation of the SBL fails. Let's try to reproduce the problem.

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver


%load_ext autoreload
%autoreload 2

%config InlineBackend.figure_format = 'svg'

In [2]:
def SBL(
    X,
    y,
    prior_init=None,
    alpha_prior=(1e-6, 1e-6),
    beta_prior=(1e-6, 1e-6),
    tol=1e-3,
    max_iter=300,
    non_diff=True
):
    """Solve for the SBL hyper-parameters and score them with `evidence`.

    Args:
        X: 2-D design matrix; its second dimension fixes the number of alphas.
        y: Regression targets, forwarded to `update` and `evidence`.
        prior_init: Starting vector `[alpha_1, ..., alpha_p, beta]`. When
            None, every alpha starts at 1 and beta at `1 / (var(y) + 1e-7)`.
        alpha_prior: Pair of hyper-prior parameters used for alpha.
        beta_prior: Pair of hyper-prior parameters used for beta.
        tol: Tolerance forwarded to `fixed_point_solver`.
        max_iter: Iteration limit forwarded to `fixed_point_solver`.
        non_diff: When true, the solved prior is wrapped in
            `jax.lax.stop_gradient` before the evidence is evaluated.

    Returns:
        Tuple `(loss, mn, prior, metrics)`: the score and coefficients
        returned by `evidence`, the prior vector the evidence was evaluated
        at, and the metrics returned by `fixed_point_solver`.
    """
    _, n_features = X.shape

    # Products reused by every update step and by the final evidence.
    XT_y = jnp.dot(X.T, y)
    gram = jnp.dot(X.T, X)

    if prior_init is None:
        beta_init = 1.0 / (jnp.var(y) + 1e-7)
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), beta_init[jnp.newaxis]], axis=0
        )

    # Ones for the alpha entries, zero for the trailing beta entry.
    # NOTE(review): how this weight is used is defined by fixed_point_solver.
    norm_weight = jnp.concatenate(
        (jnp.ones((n_features,)), jnp.zeros((1,))), axis=0
    )

    prior_params, metrics = fixed_point_solver(
        update,
        (X, y, gram, XT_y, alpha_prior, beta_prior),
        prior_init,
        norm_weight,
        tol=tol,
        max_iter=max_iter,
    )

    # With non_diff set, gradients do not flow back through the solved prior.
    prior = jax.lax.stop_gradient(prior_params) if non_diff else prior_params
    loss, mn = evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior)

    return loss, mn, prior, metrics

def update_sigma(gram, alpha, beta):
    """Return `pinv(L).T @ pinv(L)` for the Cholesky factor L of the precision.

    The precision matrix is `diag(alpha) + beta * gram`. When it is positive
    definite, the returned product equals its inverse.
    """
    precision_matrix = jnp.diag(alpha) + beta * gram
    chol_factor = jnp.linalg.cholesky(precision_matrix)
    chol_factor_inv = jnp.linalg.pinv(chol_factor)
    return chol_factor_inv.T @ chol_factor_inv


def update_coeff(XT_y, beta, sigma_):
    """Return the coefficient estimate `beta * (sigma_ @ XT_y)`."""
    projected = jnp.linalg.multi_dot([sigma_, XT_y])
    return beta * projected


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    """Apply one fixed-point update to the vector `[alpha_1, ..., alpha_p, beta]`.

    Args:
        prior: Current hyper-parameters; all but the last entry are the
            alphas, the last entry is beta.
        X: 2-D design matrix.
        y: Regression targets.
        gram: Gram matrix passed to `update_sigma`.
        XT_y: Feature/target product passed to `update_coeff`.
        alpha_prior: Pair of hyper-prior parameters for alpha.
        beta_prior: Pair of hyper-prior parameters for beta.

    Returns:
        Updated hyper-parameter vector with the same layout as `prior`.
    """
    n_samples, _ = X.shape
    alpha = prior[:-1]
    beta = prior[-1]

    # Posterior covariance and coefficients at the current hyper-parameters.
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Sum of squared residuals and the per-feature gamma terms.
    residual = y - jnp.dot(X, coeffs)
    squared_error = jnp.sum(residual ** 2)
    gamma = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha_numerator = gamma + 2.0 * alpha_prior[0]
    alpha_denominator = coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1]
    new_alpha = alpha_numerator / alpha_denominator

    beta_numerator = n_samples - gamma.sum() + 2.0 * beta_prior[0]
    beta_denominator = squared_error + 2.0 * beta_prior[1]
    new_beta = beta_numerator / beta_denominator

    return jnp.concatenate([new_alpha, new_beta[jnp.newaxis]], axis=0)


def evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior):
    """Evaluate the evidence score and coefficients for a given prior vector.

    Args:
        X: 2-D design matrix.
        y: Regression targets.
        prior: Hyper-parameter vector; all but the last entry are the alphas,
            the last entry is beta.
        gram: Gram matrix passed to `update_sigma`.
        XT_y: Feature/target product passed to `update_coeff`.
        alpha_prior: Pair of hyper-prior parameters for alpha.
        beta_prior: Pair of hyper-prior parameters for beta.

    Returns:
        Tuple `(score, coeffs)`: the squeezed score and the coefficients
        computed by `update_coeff`.
    """
    n_samples, _ = X.shape
    alpha = prior[:-1]
    beta = prior[-1]

    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)
    squared_error = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)

    log_alpha = jnp.log(alpha)
    log_beta = jnp.log(beta)

    # Hyper-prior contributions of alpha and beta.
    alpha_term = jnp.sum(alpha_prior[0] * log_alpha - alpha_prior[1] * alpha)
    beta_term = beta_prior[0] * log_beta - beta_prior[1] * beta

    # Log-determinant term and the data/weight misfit term.
    _, logdet_sigma = jnp.linalg.slogdet(sigma)
    logdet_term = 0.5 * (
        logdet_sigma + n_samples * log_beta + jnp.sum(log_alpha)
    )
    misfit_term = 0.5 * (
        beta * squared_error + jnp.sum(alpha * coeffs.squeeze() ** 2)
    )

    score = alpha_term + beta_term + logdet_term - misfit_term
    return score.squeeze(), coeffs


In [3]:
def loss_fn_SBL(params, state, model, X, y, warm_restart=True, non_diff=True):
    """Compute the negative sum of the data log-likelihood and SBL evidence.

    Args:
        params: Model parameters, placed under the "params" key of the
            variables handed to `model.apply`.
        state: Pair `(model_state, loss_state)`. `loss_state` must contain a
            "prior_init" entry (None, or a prior returned by an earlier call).
        model: Object whose `apply` returns
            `((prediction, dt, theta, coeffs), updated_model_state)`.
        X: Model inputs.
        y: Targets compared against `prediction`.
        warm_restart: When exactly False, the stored prior is ignored and
            the SBL iteration starts from the default initial prior.
        non_diff: Forwarded to `SBL`.

    Returns:
        `(loss, ((updated_model_state, loss_state), metrics, (prediction, dt,
        theta, mn)))`. The returned `loss_state` is a new dict carrying the
        solved prior under "prior_init"; the caller's dict is left unmodified
        so the function stays free of side effects under `jax.grad`/`jax.jit`.
    """
    model_state, loss_state = state
    variables = {"params": params, **model_state}
    (prediction, dt, theta, coeffs), updated_model_state = model.apply(
        variables, X, mutable=list(model_state.keys())
    )

    n_samples, n_features = theta.shape
    prior_params_mse = (0.0, 0.0)

    # Data-fit part: precision and log-likelihood of the prediction.
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)

    # Regression part: tau enters the beta hyper-prior as a constant, so
    # stop_gradient blocks backpropagation through it.
    hyper_prior_params = (
        n_samples / 2,
        n_samples / (2 * jax.lax.stop_gradient(tau)),
    )
    # Scale every library column to unit norm.
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)

    if (loss_state["prior_init"] is None) or (warm_restart is False):
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
        )
    else:
        prior_init = loss_state["prior_init"]

    p_reg, mn, prior, fwd_metric = SBL(
        theta_normed,
        dt,
        prior_init=prior_init,
        beta_prior=hyper_prior_params,
        tol=1e-3,
        max_iter=300,
        non_diff=non_diff
    )

    Reg = jnp.mean((dt - theta_normed @ mn) ** 2)

    # Build a fresh dict rather than writing into the caller's `loss_state`:
    # mutating an argument is a side effect, and outside jit it could leave a
    # tracer behind in the caller's dict.
    updated_loss_state = {**loss_state, "prior_init": prior}
    loss = -(p_mse + p_reg)
    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": Reg,
        "bayes_coeffs": mn,
        "coeffs": coeffs,
        "alpha": prior[:-1],
        "beta": prior[-1],
        "tau": tau,
        "its": fwd_metric[0],
        "gap": fwd_metric[1],
    }

    return (
        loss,
        (
            (updated_model_state, updated_loss_state),
            metrics,
            (prediction, dt, theta, mn),
        ),
    )


In [None]:
# %% Making data
key = random.PRNGKey(42)

# 20 time points by 50 space points; "ij" indexing puts time on the first axis.
t = jnp.linspace(0.5, 5.0, 20)
x = jnp.linspace(-3, 4, 50)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

# One (t, x) row per grid point; targets as a single column.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y = u.reshape(-1, 1)
# Add Gaussian noise scaled to 10% of the standard deviation of the signal.
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)


Let's first try the non-diff version (`non_diff=True`).

In [9]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=True — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, True))


In [10]:
# Train with the update function built above. The last argument (1000) is
# presumably the number of iterations — confirm against train_max_iter.
optimizer, state = train_max_iter(update_fn, optimizer, state, 1000)

Loss step 0: -2912.772705078125
Loss step 100: -7139.28173828125
Loss step 200: -8828.7109375
Loss step 300: -9576.623046875
Loss step 400: -9888.60546875
Loss step 500: -9978.7890625
Loss step 600: -10031.130859375
Loss step 700: -10059.52734375
Loss step 800: -10077.486328125
Loss step 900: -10091.45703125


So that works fine. Now let's try the differentiable version (`non_diff=False`).

In [14]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=False — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))


In [15]:
# Train with the update function built above. The last argument (1000) is
# presumably the number of iterations — confirm against train_max_iter.
optimizer, state = train_max_iter(update_fn, optimizer, state, 1000)

Loss step 0: -2912.772705078125
Loss step 100: -7139.28271484375
Loss step 200: -8828.697265625
Loss step 300: -9576.599609375
Loss step 400: -9888.5712890625
Loss step 500: -9978.8037109375
Loss step 600: -10031.4814453125
Loss step 700: -10059.626953125
Loss step 800: -10077.3505859375
Loss step 900: -10092.34375


Okay, so it actually works for the first 1000 iterations in both cases (although it doesn't seem to make a difference...). I had more issues with KdV, so let's try that.

In [4]:
key = random.PRNGKey(42)
# 10 time points by 100 space points; "ij" indexing puts time on the first axis.
t = jnp.linspace(0.1, 1.0, 10)
x = jnp.linspace(-10, 10, 100)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

# One (t, x) row per grid point; targets as a single column.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y = u.reshape(-1, 1)
# Add Gaussian noise scaled to 10% of the standard deviation of the signal.
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [18]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=True — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, True))


In [19]:
# Train with the update function built above. The last argument (1000) is
# presumably the number of iterations — confirm against train_max_iter.
optimizer, state = train_max_iter(update_fn, optimizer, state, 1000)

Loss step 0: 577.94580078125
Loss step 100: -463.72625732421875
Loss step 200: -1174.6553955078125
Loss step 300: -2272.059326171875
Loss step 400: -2773.503662109375
Loss step 500: -3403.674560546875
Loss step 600: -3751.952880859375
Loss step 700: -4763.6982421875
Loss step 800: -6398.29296875
Loss step 900: -6664.77197265625


So that works fine. Now let's try the differentiable version (`non_diff=False`).

In [5]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=False — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))


In [6]:
# Train with the update function built above. The last argument (1000) is
# presumably the number of iterations — confirm against train_max_iter.
# NOTE: in the recorded run this call logged step 0 and then raised
# `RuntimeError: cuSolver execution failed`.
optimizer, state = train_max_iter(update_fn, optimizer, state, 1000)

Loss step 0: 577.9456787109375


RuntimeError: cuSolver execution failed

So it fails before the first 100 iterations. Let's see if we can pinpoint the failing iteration more precisely:

In [7]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=False — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))

In [8]:
# Step the optimizer one update at a time so the failing iteration is visible.
# A plain `range` is used for the counter: iterating over `jnp.arange` creates
# a device array per step for no benefit, and the printed text is identical.
for it in range(100):
    (optimizer, state), metrics, output = update_fn(optimizer, state)
    print(f"Succesfully run iteration {it}")

Succesfully run iteration 0


RuntimeError: cuSolver execution failed

Okay, so the first iteration and the non-diff version both work fine, so it has to do with the backward pass for the SBL. Let's do a few quick checks to make sure it's not jit:

In [None]:
# %% Building model and params
# Fresh network for this run. The list is passed straight to Deepmod
# (presumably the layer widths — confirm against modax.models).
model = Deepmod([30, 30, 30, 1])
# NOTE(review): `key` is the same PRNG key that drew the data noise in the
# data cell; JAX recommends splitting a key instead of reusing it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Split the "params" entry off from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

# Pair the model state with the loss state; "prior_init" starts as None so
# the first call to loss_fn_SBL builds its default initial prior.
state = (state, {"prior_init": None})  # adding prior to state
# The two trailing flags are presumably forwarded positionally as
# warm_restart=True, non_diff=False — confirm against create_update.
update_fn = create_update(loss_fn_SBL, (model, X, y, True, False))

In [None]:
# Repeat the step-by-step run with jit disabled to rule out compilation as
# the cause of the failure.
with jax.disable_jit():
    # A plain `range` is used for the counter: iterating over `jnp.arange`
    # creates a device array per step, and the printed text is identical.
    for it in range(100):
        (optimizer, state), metrics, output = update_fn(optimizer, state)
        print(f"Succesfully run iteration {it}")

Okay, so that still doesn't work, but it's complaining about the coeff / least-squares solution, not about the SBL? Anyway, let's investigate this in a different notebook.