In this notebook we try to get a forward solver for the bayesian regression working and explicitly differentiate, and compare with the implicit.

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax

from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim
from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL


from forward_solver import fixed_point_solver_explicit, fixed_point_solver_implicit
from bayesian_regression import bayesian_regression


%load_ext autoreload
%autoreload 2

Lets first create some fake input the neural network:

# test data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

prediction, dt, theta, coeffs = model.apply(variables, X)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (jax.lax.stop_gradient(tau)))

# Bayesian regression

Let's first do a forward pass wihtout jit:

In [6]:
with jax.disable_jit():
    loss, coeffs, prior, metrics= bayesian_regression(fixed_point_solver_explicit, 
                                                        theta_normed, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-5, 
                                                        max_iter=500)

In [7]:
print(loss, coeffs, prior, metrics)

-1154.081 [[ 4.1776318e-03]
 [-1.9126892e-06]
 [ 6.5167336e-04]
 [-1.0187295e-04]
 [ 5.8338847e-03]
 [ 9.4901613e-04]
 [-2.9233226e-03]
 [ 1.2994788e-04]
 [ 4.1545983e-03]
 [ 1.2195628e-03]
 [-2.3457576e-03]
 [-1.8430690e-03]] [529.0887    2.14874] (251, 0.0)


So it works without jit - with jit it doesnt work (yet). Let's first check if we can calculate the derivative w.r.t the loss:

In [8]:
dL_dtheta = jax.grad(lambda X: bayesian_regression(fixed_point_solver_explicit, 
                                                            X, 
                                                            dt, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-5, 
                                                            max_iter=500)[0])

dL_ddt = jax.grad(lambda y: bayesian_regression(fixed_point_solver_explicit, 
                                                            theta_normed, 
                                                            y, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-5, 
                                                            max_iter=500)[0])

In [9]:
with jax.disable_jit():
    grad_theta_exp = dL_dtheta(theta_normed)
    grad_dt_exp = dL_ddt(dt)

In [10]:
grad_theta_exp.shape

(1000, 12)

In [11]:
grad_dt_exp.shape

(1000, 1)

So its not fast but we're getting results. Now let's run a pass using the implicit diff method and compare the results - they should be the same. 

In [12]:
with jax.disable_jit():
    loss, coeffs, prior, metrics= bayesian_regression(fixed_point_solver_implicit, 
                                                        theta_normed, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-5, 
                                                        max_iter=500)

In [13]:
print(loss, coeffs, prior, metrics)

-1154.081 [[ 4.1776318e-03]
 [-1.9126892e-06]
 [ 6.5167336e-04]
 [-1.0187295e-04]
 [ 5.8338847e-03]
 [ 9.4901613e-04]
 [-2.9233226e-03]
 [ 1.2994788e-04]
 [ 4.1545983e-03]
 [ 1.2195628e-03]
 [-2.3457576e-03]
 [-1.8430690e-03]] [529.0887    2.14874] (251, 0.0)


Forward pass is the same as it should be

In [14]:
dL_dtheta_imp = jax.grad(lambda X: bayesian_regression(fixed_point_solver_implicit, 
                                                            X, 
                                                            dt, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-5, 
                                                            max_iter=500)[0])

dL_ddt_imp = jax.grad(lambda y: bayesian_regression(fixed_point_solver_implicit, 
                                                            theta_normed, 
                                                            y, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-5, 
                                                            max_iter=500)[0])

In [15]:
with jax.disable_jit():
    grad_theta_imp = dL_dtheta_imp(theta_normed)
    grad_dt_imp = dL_ddt_imp(dt)

That worked - now let's check if they're similar:

In [16]:
jnp.max(jnp.abs(grad_theta_exp - grad_theta_imp))

DeviceArray(2.1466985e-07, dtype=float32)

In [17]:
jnp.max(jnp.abs(grad_dt_exp - grad_dt_imp))

DeviceArray(5.364418e-07, dtype=float32)

They are: that's good news. Let's check if we can run the implicit with jit:

In [18]:
loss, coeffs, prior, metrics= bayesian_regression(fixed_point_solver_implicit, 
                                                        theta_normed, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-5, 
                                                        max_iter=500)

In [19]:
dL_dtheta_imp = jax.grad(lambda X: bayesian_regression(fixed_point_solver_implicit, 
                                                        X, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-5, 
                                                        max_iter=500)[0])

dL_ddt_imp = jax.grad(lambda y: bayesian_regression(fixed_point_solver_implicit, 
                                                        theta_normed, 
                                                        y, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-5, 
                                                        max_iter=500)[0])

In [20]:
grad_theta_imp = dL_dtheta_imp(theta_normed)
grad_dt_imp = dL_ddt_imp(dt)

Let's see if its still the same:

In [24]:
jnp.max(jnp.abs(grad_dt_exp - grad_dt_imp))

DeviceArray(5.364418e-07, dtype=float32)

In [25]:
jnp.max(jnp.abs(grad_theta_exp - grad_theta_imp))

DeviceArray(1.671724e-07, dtype=float32)

great! so to recap, we now have an explicitly and implicitly differentiable forward solver which works on bayesian regression. The implicit differentiable solver works when jitted, the explicit not (yet). We now have two paths - we can make the explicit solver jittable, and start working on the SBL. SBL has more priority, so let's try that.