We've build an implementation, now to make sure the SBL function works and the training as well.

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.training.losses.SBL import loss_fn_SBL
from modax.linear_model.SBL import SBL

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()


%load_ext autoreload
%autoreload 2

%config InlineBackend.figure_format = 'svg'

# Making data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

state, params = variables.pop("params")

prediction, dt, theta, coeffs = model.apply({"params": params, **state}, X)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))
hyper_prior = (alpha_prior, beta_prior)

# testing function

In [6]:
loss, coeffs, prior, metrics = SBL(theta_normed, dt, None, hyper_prior, tol=1e-5, max_iter=1000)

In [7]:
print(metrics)

(DeviceArray(472, dtype=int32), DeviceArray(1.0035904e-05, dtype=float32))


Now let's try with a prior init:

In [8]:
SBL(theta_normed, dt, prior, hyper_prior, tol=1e-5, max_iter=1000)

(DeviceArray(455.70856, dtype=float32),
 DeviceArray([[ 4.8159838e-02],
              [-1.7831131e-05],
              [ 6.5489697e-01],
              [-8.3192805e-05],
              [ 1.6062738e+00],
              [ 6.3539577e-05],
              [-6.0119817e-04],
              [ 7.5133808e-04],
              [-3.0716820e-04],
              [ 4.9009838e-04],
              [ 6.5006682e-04],
              [-4.6149743e-04]], dtype=float32),
 DeviceArray([3.8262371e+01, 1.3062083e+03, 1.5917872e+00, 1.4582084e+03,
              3.5256827e-01, 1.4764838e+03, 1.0433547e+03, 1.2462081e+03,
              8.3055188e+02, 1.3493473e+03, 8.2135358e+02, 1.3354717e+03,
              4.2957764e+00], dtype=float32),
 (DeviceArray(1, dtype=int32), DeviceArray(1., dtype=float32)))

In [9]:
test = lambda x: SBL(x, dt, prior, hyper_prior, tol=1e-5, max_iter=1000)[0]

In [10]:
test(theta_normed)

DeviceArray(455.70856, dtype=float32)

In [11]:
jax.grad(lambda x: SBL(x, dt, prior, hyper_prior, tol=1e-5, max_iter=1000)[0])(theta_normed)

DeviceArray([[-1.2440701e-03,  3.8723842e-06,  2.5634898e-02, ...,
               1.9643254e-05,  1.8920218e-05, -2.0201613e-05],
             [-1.3898911e-03,  4.0267396e-06,  2.3853581e-02, ...,
               1.8187544e-05,  1.7692948e-05, -1.8726023e-05],
             [-1.5444972e-03,  4.1821468e-06,  2.1956580e-02, ...,
               1.6643622e-05,  1.6365988e-05, -1.7164115e-05],
             ...,
             [ 5.0725369e-04,  4.6074656e-06,  3.1805620e-02, ...,
               3.5687837e-05, -2.1029897e-05, -4.3647586e-05],
             [ 4.1563646e-04,  4.5651877e-06,  3.0797089e-02, ...,
               3.4767811e-05, -2.1377313e-05, -4.2674856e-05],
             [ 3.2546185e-04,  4.5169600e-06,  2.9804943e-02, ...,
               3.3862052e-05, -2.1718381e-05, -4.1717405e-05]],            dtype=float32)

# testing training

In [12]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, True))

In [13]:
train_max_iter(update_fn, optimizer, state, 5000)

Loss step 0: 577.9456787109375
Loss step 100: -463.77923583984375
Loss step 200: -1175.1824951171875
Loss step 300: -2273.142578125
Loss step 400: -2774.034423828125
Loss step 500: -3403.81982421875
Loss step 600: -3751.59912109375
Loss step 700: -4754.9013671875
Loss step 800: -6389.67236328125
Loss step 900: -6642.81005859375
Loss step 1000: -6722.11767578125
Loss step 1100: -6735.8134765625
Loss step 1200: -6743.01611328125
Loss step 1300: -6745.375
Loss step 1400: -6747.2119140625
Loss step 1500: -6759.123046875
Loss step 1600: -6755.7216796875
Loss step 1700: -6745.0556640625
Loss step 1800: -6749.6240234375
Loss step 1900: -6761.83837890625
Loss step 2000: -6762.3427734375
Loss step 2100: -6754.748046875
Loss step 2200: -6758.02734375
Loss step 2300: -6759.517578125
Loss step 2400: -6749.99755859375
Loss step 2500: -6767.84716796875
Loss step 2600: -6756.38671875
Loss step 2700: -6757.427734375
Loss step 2800: -6766.6083984375
Loss step 2900: -6764.9697265625
Loss step 3000: -675

KeyboardInterrupt: 