In this notebook we implemented a jittable explicit differentiation method.

In [20]:
# %% Imports
from jax import numpy as jnp, random
import jax

from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim
from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL


from forward_solver import fixed_point_solver_explicit, fixed_point_solver_implicit
from SBL import SBL
from bayesian_regression import bayesian_regression


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Lets first create some fake input the neural network:

# test data

In [2]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [3]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

prediction, dt, theta, coeffs = model.apply(variables, X)

In [4]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [5]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (jax.lax.stop_gradient(tau)))

# Non-jit with SBL

Let's first do a forward pass wihtout jit:

In [6]:
%%time
with jax.disable_jit():
    loss, coeffs, prior, metrics= SBL(fixed_point_solver_explicit, 
                                                        theta, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-4, 
                                                        max_iter=500)

CPU times: user 1.44 s, sys: 128 ms, total: 1.57 s
Wall time: 3.2 s


In [7]:
print(loss, coeffs, prior, metrics)

-234.16995 [[ 2.88100680e-03]
 [-5.47068566e-06]
 [ 5.07184528e-02]
 [-1.04297855e-04]
 [ 7.30849147e-01]
 [ 4.23768885e-04]
 [-2.34779727e-04]
 [ 5.77287283e-04]
 [ 2.41000613e-04]
 [ 5.97529579e-04]
 [ 5.58131258e-04]
 [-3.33002099e-04]] [1.2790446e+04 2.7565962e+03 5.8329128e+01 4.3158569e+03 1.4594892e+00
 1.0515136e+02 1.9514742e+02 4.4390738e+02 1.3002066e+02 1.2394662e+01
 2.1444576e+01 4.7285149e+01 2.1528165e+00] (50, DeviceArray(9.194642e-05, dtype=float32))


So it works without jit - with jit it doesnt work (yet). Let's first check if we can calculate the derivative w.r.t the loss:

In [8]:
dL_dtheta = jax.grad(lambda X: SBL(fixed_point_solver_explicit, 
                                    X, 
                                    dt, 
                                    prior_init=None, 
                                    hyper_prior=(alpha_prior, beta_prior), 
                                    tol=1e-4, 
                                    max_iter=500)[0])

dL_ddt = jax.grad(lambda y: SBL(fixed_point_solver_explicit, 
                                    theta, 
                                    y, 
                                    prior_init=None, 
                                    hyper_prior=(alpha_prior, beta_prior), 
                                    tol=1e-4, 
                                    max_iter=500)[0])

In [9]:
with jax.disable_jit():
    grad_theta_exp = dL_dtheta(theta)
    grad_dt_exp = dL_ddt(dt)

In [10]:
grad_theta_exp.shape

(1000, 12)

In [11]:
grad_dt_exp.shape

(1000, 1)

In [12]:
jnp.any(jnp.isnan(grad_theta_exp))

DeviceArray(False, dtype=bool)

In [13]:
jnp.any(jnp.isnan(grad_dt_exp))

DeviceArray(False, dtype=bool)

Now move on to the jit - 

# With jit

In [None]:
%%time
loss, coeffs, prior, metrics= SBL(fixed_point_solver_explicit, 
                                                    theta, 
                                                    dt, 
                                                    prior_init=None, 
                                                    hyper_prior=(alpha_prior, beta_prior), 
                                                    tol=1e-4, 
                                                    max_iter=500)

This fails at the conditional function, probably becauyse of the jax.lax.cond - lets fix that.

In [21]:
%%time
loss, coeffs, prior, metrics= SBL(fixed_point_solver_explicit, 
                                                    theta, 
                                                    dt, 
                                                    prior_init=None, 
                                                    hyper_prior=(alpha_prior, beta_prior), 
                                                    tol=1e-4, 
                                                    max_iter=500)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/3)>
The problem arose with the `bool` function. 
While tracing the function SBL at /home/gert-jan/Documents/modax/notebooks/differentiable_fwd_solver/SBL.py:81, this concrete value was not available in Python because it depends on the value of the arguments to SBL at /home/gert-jan/Documents/modax/notebooks/differentiable_fwd_solver/SBL.py:81 at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax._src.errors.ConcretizationTypeError)