We've build an implementation, now to make sure the SBL function works.

In [8]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.linear_model.SBL import SBL


from flax.core import unfreeze
from flax.traverse_util import flatten_dict
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()


%load_ext autoreload
%autoreload 2

%config InlineBackend.figure_format = 'svg'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Making data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

state, params = variables.pop("params")

prediction, dt, theta, coeffs = model.apply({"params": params, **state}, X)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))
hyper_prior = (alpha_prior, beta_prior)

# testing

In [9]:
SBL(theta_normed, dt, None, hyper_prior, tol=1e-5, max_iter=1000, stop_prior_grad=True)

(DeviceArray(455.70856, dtype=float32),
 DeviceArray([[ 4.8159838e-02],
              [-1.7831131e-05],
              [ 6.5489697e-01],
              [-8.3192805e-05],
              [ 1.6062738e+00],
              [ 6.3539577e-05],
              [-6.0119817e-04],
              [ 7.5133808e-04],
              [-3.0716820e-04],
              [ 4.9009838e-04],
              [ 6.5006682e-04],
              [-4.6149743e-04]], dtype=float32),
 DeviceArray([3.8262371e+01, 1.3062083e+03, 1.5917872e+00, 1.4582084e+03,
              3.5256827e-01, 1.4764838e+03, 1.0433547e+03, 1.2462081e+03,
              8.3055188e+02, 1.3493473e+03, 8.2135358e+02, 1.3354717e+03,
              4.2957764e+00], dtype=float32),
 DeviceArray(471, dtype=int32))