We've build an implementation, now to make sure the SBL function works and the training as well.

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.training.losses.SBL import loss_fn_SBL
from modax.linear_model.SBL import SBL

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()


%load_ext autoreload
%autoreload 2

%config InlineBackend.figure_format = 'svg'

# Making data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

state, params = variables.pop("params")

prediction, dt, theta, coeffs = model.apply({"params": params, **state}, X)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))
hyper_prior = (alpha_prior, beta_prior)

# testing function

In [6]:
loss, coeffs, prior, metrics = SBL(theta_normed, dt, None, hyper_prior, tol=1e-4, max_iter=1000)

Now let's try with a prior init:

In [7]:
prior.shape

(13,)

In [8]:
SBL(theta_normed, dt, prior, hyper_prior, tol=1e-4, max_iter=1000)

(DeviceArray(455.70786, dtype=float32),
 DeviceArray([[ 4.8160080e-02],
              [-1.2110626e-05],
              [ 6.5487993e-01],
              [-5.9330454e-05],
              [ 1.6062669e+00],
              [ 1.1196396e-04],
              [-6.0146576e-04],
              [ 7.5185619e-04],
              [-3.0715723e-04],
              [ 4.8824094e-04],
              [ 6.4991473e-04],
              [-4.6195151e-04]], dtype=float32),
 DeviceArray([3.8262333e+01, 1.9324736e+03, 1.5918189e+00, 2.0449115e+03,
              3.5256574e-01, 8.3650781e+02, 1.0429813e+03, 1.2454182e+03,
              8.3056873e+02, 1.3541514e+03, 8.2148560e+02, 1.3340739e+03,
              4.2957740e+00], dtype=float32),
 (DeviceArray(2, dtype=int32),
  DeviceArray(9.817844e-05, dtype=float32),
  DeviceArray([ 4.81602065e-02, -2.11484530e-05,  6.54910684e-01,
               -1.17133066e-04,  1.60627890e+00,  4.04912862e-05,
               -6.01166335e-04,  7.51212006e-04, -3.07230104e-04,
                4.

In [9]:
test = lambda x: SBL(x, dt, prior, hyper_prior, tol=1e-5, max_iter=1000)[0]

In [10]:
test(theta_normed)

DeviceArray(455.70853, dtype=float32)

In [11]:
jax.grad(lambda x: SBL(x, dt, prior, hyper_prior, tol=1e-5, max_iter=1000)[0])(theta_normed)

DeviceArray([[-1.2440758e-03,  3.0951755e-06,  2.5634721e-02, ...,
               1.9645027e-05,  1.8919467e-05, -2.0202955e-05],
             [-1.3898984e-03,  3.2186463e-06,  2.3853416e-02, ...,
               1.8189390e-05,  1.7692322e-05, -1.8727380e-05],
             [-1.5445057e-03,  3.3429608e-06,  2.1956421e-02, ...,
               1.6645543e-05,  1.6365502e-05, -1.7165479e-05],
             ...,
             [ 5.0727581e-04,  3.6803476e-06,  3.1805180e-02, ...,
               3.5687346e-05, -2.1028929e-05, -4.3648466e-05],
             [ 4.1565625e-04,  3.6466095e-06,  3.0796645e-02, ...,
               3.4767319e-05, -2.1376327e-05, -4.2675711e-05],
             [ 3.2548141e-04,  3.6081119e-06,  2.9804520e-02, ...,
               3.3861583e-05, -2.1717336e-05, -4.1718253e-05]],            dtype=float32)

# testing training

In [12]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True))

In [13]:
train_max_iter(update_fn, optimizer, state, 5000)

Loss step 0: 577.9735107421875
Loss step 100: -464.96575927734375
Loss step 200: -1191.73095703125
Loss step 300: -2329.589111328125
Loss step 400: -2796.3388671875
Loss step 500: -3382.09423828125
Loss step 600: -3809.9609375
Loss step 700: -5198.33837890625
Loss step 800: -6581.451171875
Loss step 900: -6708.7314453125
Loss step 1000: -6730.92041015625
Loss step 1100: -6745.6376953125
Loss step 1200: -6759.55615234375
Loss step 1300: -6753.66064453125
Loss step 1400: -6761.625
Loss step 1500: -6766.24365234375
Loss step 1600: -6766.02392578125
Loss step 1700: -6766.37353515625
Loss step 1800: -6748.0341796875
Loss step 1900: -6762.3828125
Loss step 2000: -6765.91162109375
Loss step 2100: -6767.20068359375
Loss step 2200: -6767.728515625
Loss step 2300: -6766.90478515625
Loss step 2400: -6771.333984375
Loss step 2500: -6761.43359375
Loss step 2600: -6767.7705078125
Loss step 2700: -6768.23046875
Loss step 2800: -6760.466796875
Loss step 2900: -6770.3359375
Loss step 3000: -6773.139648

(Optimizer(optimizer_def=<flax.optim.adam.Adam object at 0x7f59072e02b0>, state=OptimizerState(step=DeviceArray(5000, dtype=int32), param_states=FrozenDict({
     MLP_0: {
         Dense_0: {
             bias: _AdamParamState(grad_ema=DeviceArray([ 1.7427711 , -0.44478238, -1.4968688 , -2.3397274 ,
                          -0.67981285,  0.4681831 , -0.24797514, -0.31891882,
                           0.91501766, -0.2255978 , -1.0538245 , -1.6020634 ,
                          -0.8073244 , -0.23093545, -0.05668622,  2.2742722 ,
                          -0.03651446,  0.74621314, -0.30817723,  1.5254296 ,
                           0.02219486, -2.3817186 , -0.27379602, -0.50072443,
                           3.074418  ,  1.314918  ,  1.8695853 , -1.5731059 ,
                           2.8372087 , -1.0359031 ], dtype=float32), grad_sq_ema=DeviceArray([159251.2   ,  22437.486 , 149047.03  ,  80298.56  ,
                            1388.0195,   7270.5576,  77696.6   ,  31400.562 ,
       