In this notebook we try to get a forward solver for the SBL working and explicitly differentiate, and compare with the implicit.

In [19]:
# %% Imports
from jax import numpy as jnp, random
import jax

from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim
from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL


from forward_solver import fixed_point_solver_explicit, fixed_point_solver_implicit
from SBL import SBL


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Lets first create some fake input the neural network:

# test data

In [14]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [15]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

prediction, dt, theta, coeffs = model.apply(variables, X)

In [16]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [17]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (jax.lax.stop_gradient(tau)))

# SBL

Let's first do a forward pass wihtout jit:

In [20]:
with jax.disable_jit():
    loss, coeffs, prior, metrics= SBL(fixed_point_solver_explicit, 
                                                        theta, 
                                                        dt, 
                                                        prior_init=None, 
                                                        hyper_prior=(alpha_prior, beta_prior), 
                                                        tol=1e-4, 
                                                        max_iter=500)

In [21]:
print(loss, coeffs, prior, metrics)

-234.16673 [[ 2.8835083e-03]
 [-7.3347251e-06]
 [ 5.0495405e-02]
 [-5.9845635e-05]
 [ 7.3062313e-01]
 [ 4.2678340e-04]
 [-2.3846346e-04]
 [ 5.8384158e-04]
 [ 2.2641293e-04]
 [ 5.9847737e-04]
 [ 5.5580394e-04]
 [-3.4260462e-04]] [1.2780464e+04 1.9792905e+03 5.8624344e+01 7.6146699e+03 1.4603415e+00
 1.0457022e+02 1.9202785e+02 4.3865234e+02 1.3856334e+02 1.2384184e+01
 2.1545572e+01 4.6018299e+01 2.1528099e+00] (300, 0)


So it works without jit - with jit it doesnt work (yet). Let's first check if we can calculate the derivative w.r.t the loss:

In [26]:
dL_dtheta = jax.grad(lambda X: SBL(fixed_point_solver_explicit, 
                                    X, 
                                    dt, 
                                    prior_init=None, 
                                    hyper_prior=(alpha_prior, beta_prior), 
                                    tol=1e-4, 
                                    max_iter=500)[2][0])

dL_ddt = jax.grad(lambda y: SBL(fixed_point_solver_explicit, 
                                    theta, 
                                    y, 
                                    prior_init=None, 
                                    hyper_prior=(alpha_prior, beta_prior), 
                                    tol=1e-4, 
                                    max_iter=500)[2][0])

In [27]:
with jax.disable_jit():
    grad_theta_exp = dL_dtheta(theta)
    grad_dt_exp = dL_ddt(dt)

In [29]:
grad_theta_exp

DeviceArray([[ 4.0315060e+01, -7.8661844e-02,  2.5311833e+02, ...,
               1.2050523e+00,  1.3337393e+00, -6.8289971e-01],
             [ 4.2319916e+01, -8.1662700e-02,  2.4128146e+02, ...,
               1.2297906e+00,  1.3528125e+00, -6.8435913e-01],
             [ 4.4433788e+01, -8.4824540e-02,  2.2870581e+02, ...,
               1.2557112e+00,  1.3726672e+00, -6.8576968e-01],
             ...,
             [-5.3714981e+00, -5.8753863e-03,  3.5168759e+02, ...,
               3.2359827e-01,  3.5291189e-01, -4.1627902e-01],
             [-3.9541931e+00, -8.0396160e-03,  3.4455737e+02, ...,
               3.4352410e-01,  3.7065768e-01, -4.1905183e-01],
             [-2.5585327e+00, -1.0170683e-02,  3.3754132e+02, ...,
               3.6315131e-01,  3.8813728e-01, -4.2178833e-01]],            dtype=float32)

In [25]:
grad_dt_exp.shape

(1000, 1)

In [37]:
jnp.any(jnp.isnan(grad_theta_exp))

DeviceArray(False, dtype=bool)

In [38]:
jnp.any(jnp.isnan(grad_dt_exp))

DeviceArray(False, dtype=bool)

No nan- seems fine.

So its not fast but we're getting results. Now let's run a pass using the implicit diff method and compare the results - they should be the same. 

In [39]:
with jax.disable_jit():
    loss, coeffs, prior, metrics= SBL(fixed_point_solver_implicit, 
                                        theta, 
                                        dt, 
                                        prior_init=None, 
                                        hyper_prior=(alpha_prior, beta_prior), 
                                        tol=1e-4, 
                                        max_iter=500)

In [40]:
print(loss, coeffs, prior, metrics)

-234.16995 [[ 2.8810082e-03]
 [-5.4706748e-06]
 [ 5.0718643e-02]
 [-1.0429989e-04]
 [ 7.3084915e-01]
 [ 4.2338760e-04]
 [-2.3527378e-04]
 [ 5.7718530e-04]
 [ 2.4089876e-04]
 [ 5.9699081e-04]
 [ 5.5704353e-04]
 [-3.3291994e-04]] [1.2790438e+04 2.7565962e+03 5.8328896e+01 4.3157676e+03 1.4594887e+00
 1.0524597e+02 1.9473764e+02 4.4398584e+02 1.3007578e+02 1.2405847e+01
 2.1486467e+01 4.7296772e+01 2.1528165e+00] (50, DeviceArray(9.250849e-05, dtype=float32))


Forward pass is the same as it should be

In [41]:
dL_dtheta_imp = jax.grad(lambda X: SBL(fixed_point_solver_implicit, 
                                                            X, 
                                                            dt, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-4, 
                                                            max_iter=500)[0])

dL_ddt_imp = jax.grad(lambda y: SBL(fixed_point_solver_implicit, 
                                                            theta, 
                                                            y, 
                                                            prior_init=None, 
                                                            hyper_prior=(alpha_prior, beta_prior), 
                                                            tol=1e-4, 
                                                            max_iter=500)[0])

In [42]:
with jax.disable_jit():
    grad_theta_imp = dL_dtheta_imp(theta)
    grad_dt_imp = dL_ddt_imp(dt)

That worked - now let's check if they're similar:

In [43]:
jnp.any(jnp.isnan(grad_theta_imp))

DeviceArray(True, dtype=bool)

In [44]:
jnp.max(jnp.abs(grad_theta_exp - grad_theta_imp))

DeviceArray(nan, dtype=float32)

In [45]:
jnp.max(jnp.abs(grad_dt_exp - grad_dt_imp))

DeviceArray(nan, dtype=float32)

Okay so as usual the implicit method gives a nan. The best approach for now is to get the jit version of the explicit method working and see that one - that would give us a) something working at least and b) something to compare the SBL with.