I'm pretty sure the issue is with the custom back-prop. If we don't backpropagate through alpha, we should still get a good result. Let's see.

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver


%load_ext autoreload
%autoreload 2

In [2]:
# %% Making data
# Noisy data set built with modax's `burgers` helper on a regular (t, x) grid.
key = random.PRNGKey(42)

# 20 time points by 50 space points; with indexing="ij" time runs along axis 0.
t = jnp.linspace(0.5, 5.0, 20)
x = jnp.linspace(-3, 4, 50)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

# Flatten the grid into (n_points, 2) inputs ordered [t, x] and (n_points, 1) targets.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y = u.reshape(-1, 1)
# Additive Gaussian noise whose scale is 10% of the standard deviation of the clean signal.
y = y + 0.10 * jnp.std(y) * random.normal(key, y.shape)


In [8]:
def SBL(
    X,
    y,
    prior_init=None,
    alpha_prior=(1e-6, 1e-6),
    beta_prior=(1e-6, 1e-6),
    tol=1e-3,
    max_iter=300,
):
    """Find the hyperparameters by fixed-point iteration and score them.

    `update` is iterated through `fixed_point_solver`, starting from
    `prior_init`; `evidence` is then evaluated at the resulting
    hyperparameters with the gradient through them blocked.

    Args:
        X: 2-D design matrix, unpacked as (n_samples, n_features).
        y: targets, combined with `X` through `jnp.dot(X.T, y)`.
        prior_init: optional starting vector holding `n_features` alpha
            entries followed by a single beta entry. When None, alpha starts
            at ones and beta at `1 / (var(y) + 1e-7)`.
        alpha_prior: pair of hyper-prior parameters for alpha, forwarded to
            `update` and `evidence`.
        beta_prior: pair of hyper-prior parameters for beta, forwarded to
            `update` and `evidence`.
        tol: tolerance forwarded to `fixed_point_solver`.
        max_iter: iteration cap forwarded to `fixed_point_solver`.

    Returns:
        Tuple `(loss, mn, prior, metrics)`: the score and coefficients
        returned by `evidence`, the stop-gradient hyperparameter vector, and
        the metrics object returned by `fixed_point_solver`.
    """
    _, n_features = X.shape

    # One weight per alpha entry and a zero for the trailing beta entry.
    # NOTE(review): presumably selects which entries the solver's convergence
    # norm looks at - confirm against fixed_point_solver.
    norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)

    if prior_init is None:
        beta_start = 1.0 / (jnp.var(y) + 1e-7)
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), beta_start[jnp.newaxis]], axis=0
        )

    # Quantities that stay constant across iterations are computed once.
    gram = X.T @ X
    XT_y = X.T @ y

    solver_args = (X, y, gram, XT_y, alpha_prior, beta_prior)
    prior_params, metrics = fixed_point_solver(
        update,
        solver_args,
        prior_init,
        norm_weight,
        tol=tol,
        max_iter=max_iter,
    )

    # Treat the converged hyperparameters as constants: no backprop through them.
    prior = jax.lax.stop_gradient(prior_params)
    loss, mn = evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior)

    return loss, mn, prior, metrics


def update_sigma(gram, alpha, beta):
    """Return the inverse of `diag(alpha) + beta * gram`.

    The inverse is assembled from the pseudo-inverse of the lower Cholesky
    factor L of that matrix, as `pinv(L).T @ pinv(L)`.

    Args:
        gram: square matrix (the callers in this file pass `X.T @ X`).
        alpha: 1-D vector placed on the diagonal.
        beta: scalar multiplying `gram`.

    Returns:
        Square matrix with the same shape as `gram`.
    """
    precision_matrix = jnp.diag(alpha) + beta * gram
    chol_lower = jnp.linalg.cholesky(precision_matrix)
    chol_inv = jnp.linalg.pinv(chol_lower)
    return chol_inv.T @ chol_inv


def update_coeff(XT_y, beta, sigma_):
    """Return `beta * (sigma_ @ XT_y)`.

    Args:
        XT_y: right-hand operand of the matrix product (callers pass `X.T @ y`).
        beta: scalar multiplier.
        sigma_: left-hand operand (callers pass the output of `update_sigma`).

    Returns:
        The scaled product, with the shape of `sigma_ @ XT_y`.
    """
    projected = jnp.linalg.multi_dot([sigma_, XT_y])
    return beta * projected


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    """Perform one fixed-point step on the hyperparameter vector.

    Args:
        prior: 1-D vector; every entry but the last is alpha, the last is beta.
        X: 2-D design matrix, unpacked as (n_samples, n_features).
        y: targets compared against `X @ coeffs`.
        gram: precomputed `X.T @ X`.
        XT_y: precomputed `X.T @ y`.
        alpha_prior: pair of hyper-prior parameters entering the alpha update.
        beta_prior: pair of hyper-prior parameters entering the beta update.

    Returns:
        New vector with the same layout as `prior` (alpha entries, then beta).
    """
    n_samples, _ = X.shape
    alpha, beta = prior[:-1], prior[-1]

    # Covariance and coefficients implied by the current hyperparameters.
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Sum of squared residuals, and the per-feature quantity 1 - alpha_i * sigma_ii.
    residual = y - jnp.dot(X, coeffs)
    sq_err = jnp.sum(residual ** 2)
    gamma = 1.0 - alpha * jnp.diag(sigma)

    # Re-estimate alpha and beta.
    # TODO: Cap alpha with some threshold.
    alpha_num = gamma + 2.0 * alpha_prior[0]
    alpha_den = coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1]
    new_alpha = alpha_num / alpha_den

    beta_num = n_samples - gamma.sum() + 2.0 * beta_prior[0]
    beta_den = sq_err + 2.0 * beta_prior[1]
    new_beta = beta_num / beta_den

    return jnp.concatenate([new_alpha, new_beta[jnp.newaxis]], axis=0)


def evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior):
    """Score a hyperparameter vector and return the matching coefficients.

    Args:
        prior: 1-D vector; every entry but the last is alpha, the last is beta.
        X: 2-D design matrix, unpacked as (n_samples, n_features).
        y: targets compared against `X @ coeffs`.
        gram: precomputed `X.T @ X`.
        XT_y: precomputed `X.T @ y`.
        alpha_prior: pair of hyper-prior parameters for alpha.
        beta_prior: pair of hyper-prior parameters for beta.

    Returns:
        Tuple `(score, coeffs)`: the squeezed scalar score and the coefficients
        computed by `update_coeff` for these hyperparameters.
    """
    n_samples, _ = X.shape
    alpha, beta = prior[:-1], prior[-1]

    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)
    sq_err = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)

    log_alpha = jnp.log(alpha)
    log_beta = jnp.log(beta)

    # Hyper-prior contributions for alpha and beta.
    alpha_term = jnp.sum(alpha_prior[0] * log_alpha - alpha_prior[1] * alpha)
    beta_term = beta_prior[0] * log_beta - beta_prior[1] * beta
    # Log-determinant and normalisation terms.
    log_det_term = 0.5 * (
        jnp.linalg.slogdet(sigma)[1] + n_samples * log_beta + jnp.sum(log_alpha)
    )
    # Penalty from the residual and from the size of the coefficients.
    fit_term = 0.5 * (beta * sq_err + jnp.sum(alpha * coeffs.squeeze() ** 2))

    # Summed in the same left-to-right order as the individual terms above.
    score = alpha_term + beta_term + log_det_term - fit_term

    return score.squeeze(), coeffs


In [9]:
def loss_fn_SBL(params, state, model, X, y, warm_restart=True):
    """Loss combining the data-fit likelihood with the SBL regression score.

    Args:
        params: model parameters, merged with the model state into the
            variables passed to `model.apply`.
        state: pair `(model_state, loss_state)`; `loss_state` is a mapping
            whose `"prior_init"` entry is either None or the hyperparameter
            vector from a previous call.
        model: object whose `apply` returns
            `((prediction, dt, theta, coeffs), updated_model_state)`.
        X: model inputs.
        y: targets compared against `prediction`.
        warm_restart: when False, the stored `"prior_init"` is ignored and the
            hyperparameters are re-initialised on every call.

    Returns:
        `(loss, ((updated_model_state, loss_state), metrics,
        (prediction, dt, theta, mn)))`, where `loss` is `-(p_mse + p_reg)` and
        the returned `loss_state` is a new dict carrying the latest
        hyperparameters under `"prior_init"`. The caller's `loss_state` is not
        modified.
    """
    model_state, loss_state = state
    variables = {"params": params, **model_state}
    (prediction, dt, theta, coeffs), updated_model_state = model.apply(
        variables, X, mutable=list(model_state.keys())
    )

    n_samples, n_features = theta.shape
    prior_params_mse = (0.0, 0.0)

    # MSE stuff
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)

    # Regression stuff
    # The beta hyper-prior is built from tau, but no gradient should flow
    # back into tau through it.
    hyper_prior_params = (
        n_samples / 2,
        n_samples / (2 * jax.lax.stop_gradient(tau)),
    )
    # Scale every library column to unit Euclidean norm.
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)

    if (loss_state["prior_init"] is None) or (warm_restart is False):
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
        )
    else:
        prior_init = loss_state["prior_init"]

    p_reg, mn, prior, fwd_metric = SBL(
        theta_normed,
        dt,
        prior_init=prior_init,
        beta_prior=hyper_prior_params,
        tol=1e-3,
        max_iter=300,
    )

    Reg = jnp.mean((dt - theta_normed @ mn) ** 2)

    # Build a fresh loss state instead of assigning into the caller's dict:
    # the function stays free of side effects on its arguments, and no traced
    # value is written into caller-owned state when this runs outside jit.
    new_loss_state = {**loss_state, "prior_init": prior}

    loss = -(p_mse + p_reg)
    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": Reg,
        "bayes_coeffs": mn,
        "coeffs": coeffs,
        "alpha": prior[:-1],
        "beta": prior[-1],
        "tau": tau,
        "its": fwd_metric[0],
        "gap": fwd_metric[1],
    }

    return (
        loss,
        ((updated_model_state, new_loss_state), metrics, (prediction, dt, theta, mn)),
    )


In [10]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)
# Separate the trainable parameters from the remaining model state.
state, params = variables.pop("params")

# Adam optimizer bound to the network parameters.
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)

# The loss keeps its own state next to the model state; "prior_init" starts
# empty so the first call initialises the hyperparameters itself.
state = (state, {"prior_init": None})  # adding prior to state
# The trailing True is the warm_restart argument of loss_fn_SBL.
update_fn = create_update(loss_fn_SBL, (model, X, y, True))


In [11]:
# Train on the Burgers data with the SBL loss.
# NOTE(review): 5000 is presumably the iteration budget of train_max_iter - confirm.
optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

Loss step 0: -2912.77294921875
Loss step 100: -7139.28369140625
Loss step 200: -8828.705078125
Loss step 300: -9576.6103515625
Loss step 400: -9888.5791015625
Loss step 500: -9978.798828125
Loss step 600: -10031.1171875
Loss step 700: -10059.4697265625
Loss step 800: -10077.68359375
Loss step 900: -10092.5712890625
Loss step 1000: -10095.96484375
Loss step 1100: -10100.787109375
Loss step 1200: -10103.408203125
Loss step 1300: -10103.4306640625
Loss step 1400: -10104.515625
Loss step 1500: -10104.8994140625
Loss step 1600: -10105.02734375
Loss step 1700: -10104.728515625
Loss step 1800: -10104.365234375
Loss step 1900: -10105.572265625
Loss step 2000: -10105.1181640625
Loss step 2100: -10105.912109375
Loss step 2200: -10105.462890625
Loss step 2300: -10104.85546875
Loss step 2400: -10106.134765625
Loss step 2500: -10106.314453125
Loss step 2600: -10106.23828125
Loss step 2700: -10106.544921875
Loss step 2800: -10105.14453125
Loss step 2900: -10106.2958984375
Loss step 3000: -10105.8378

Wow, that worked amazingly well - let's try it with KdV.

In [12]:
# Noisy data set built with modax's `doublesoliton` helper on a regular (t, x) grid.
key = random.PRNGKey(42)

# 10 time points by 100 space points; with indexing="ij" time runs along axis 0.
t = jnp.linspace(0.1, 1.0, 10)
x = jnp.linspace(-10, 10, 100)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

# Flatten the grid into (n_points, 2) inputs ordered [t, x] and (n_points, 1) targets.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y = u.reshape(-1, 1)
# Additive Gaussian noise whose scale is 10% of the standard deviation of the clean signal.
y = y + 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [13]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)
# Separate the trainable parameters from the remaining model state.
state, params = variables.pop("params")

# Adam optimizer bound to the network parameters.
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)

# The loss keeps its own state next to the model state; "prior_init" starts
# empty so the first call initialises the hyperparameters itself.
state = (state, {"prior_init": None})  # adding prior to state
# The trailing True is the warm_restart argument of loss_fn_SBL.
update_fn = create_update(loss_fn_SBL, (model, X, y, True))


In [14]:
# Train on the KdV double-soliton data with the SBL loss.
# NOTE(review): 5000 is presumably the iteration budget of train_max_iter - confirm.
optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

Loss step 0: 577.9468994140625
Loss step 100: -463.724853515625
Loss step 200: -1174.6539306640625
Loss step 300: -2272.0556640625
Loss step 400: -2773.500732421875
Loss step 500: -3403.671142578125
Loss step 600: -3751.59521484375
Loss step 700: -4755.078125
Loss step 800: -6396.3857421875
Loss step 900: -6653.75439453125
Loss step 1000: -6718.2197265625
Loss step 1100: -6740.65380859375
Loss step 1200: -6745.4296875
Loss step 1300: -6745.13818359375
Loss step 1400: -6755.9833984375
Loss step 1500: -6752.78564453125
Loss step 1600: -6723.474609375
Loss step 1700: -6758.52392578125
Loss step 1800: -6756.23291015625
Loss step 1900: -6753.6591796875
Loss step 2000: -6755.51025390625
Loss step 2100: -6759.00390625
Loss step 2200: -6763.03173828125
Loss step 2300: -6755.2734375
Loss step 2400: -6752.0966796875
Loss step 2500: -6763.1015625
Loss step 2600: -6758.123046875
Loss step 2700: -6765.3212890625
Loss step 2800: -6754.98291015625
Loss step 2900: -6760.314453125
Loss step 3000: -6762