So we know for sure the problem is with alpha. Let's figure it out by creating some fake input and, instead of using the optimized approach, back-propagating through the while loop in the standard way.

It seems that by using way more iterations in the forward pass I can fix it...

In [1]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver, fwd_solver, fwd_solver_simple


%load_ext autoreload
%autoreload 2

In [2]:
def update_sigma(gram, alpha, beta):
    """Return the inverse of ``diag(alpha) + beta * gram``.

    The matrix ``diag(alpha) + beta * gram`` is factorised with a Cholesky
    decomposition, the factor is (pseudo-)inverted, and the result is
    assembled as ``inv(L).T @ inv(L)``.

    Args:
        gram: Square matrix that is scaled by ``beta``.
        alpha: 1-D vector placed on the diagonal.
        beta: Scalar multiplying ``gram``.

    Returns:
        Square matrix with the same shape as ``gram``.
    """
    posterior_precision = beta * gram + jnp.diag(alpha)
    chol_factor = jnp.linalg.cholesky(posterior_precision)
    # The pseudo-inverse is used to invert the triangular factor.
    chol_factor_inv = jnp.linalg.pinv(chol_factor)
    return jnp.dot(chol_factor_inv.T, chol_factor_inv)


def update_coeff(XT_y, beta, sigma_):
    """Return ``beta * sigma_ @ XT_y``.

    Args:
        XT_y: Right-hand operand of the matrix product.
        beta: Scalar factor applied to the product.
        sigma_: Left-hand matrix of the product (the output of
            ``update_sigma`` when called from ``update``).

    Returns:
        The scaled matrix product.
    """
    return beta * jnp.linalg.multi_dot([sigma_, XT_y])


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    """Perform one re-estimation step of the packed ``(alpha, beta)`` vector.

    ``prior`` stores the per-feature ``alpha`` values followed by a single
    ``beta`` as its last entry. The function recomputes the covariance and
    coefficients for those values and returns the re-estimated vector in the
    same layout, so it can be iterated as a fixed-point map.

    Args:
        prior: 1-D vector ``[alpha_0, ..., alpha_{k-1}, beta]``.
        X: 2-D matrix; its first dimension is used as the sample count.
        y: Targets compared against ``X @ coeffs``.
        gram: Matrix passed to ``update_sigma``.
        XT_y: Vector/matrix passed to ``update_coeff``.
        alpha_prior: Indexable pair; entry 0 is added (doubled) to the
            numerator and entry 1 (doubled) to the denominator of the alpha
            update.
        beta_prior: Same layout as ``alpha_prior``, used for the beta update.

    Returns:
        1-D vector with the new alphas followed by the new beta.
    """
    # Unpacking keeps the requirement that X is 2-D; the feature count itself
    # is not needed.
    n_samples, _ = X.shape
    alpha = prior[:-1]
    beta = prior[-1]

    posterior_cov = update_sigma(gram, alpha, beta)
    weights = update_coeff(XT_y, beta, posterior_cov)

    # Sum of squared residuals of the current fit.
    residual = y - jnp.dot(X, weights)
    sum_sq_residual = jnp.sum(residual ** 2)
    gamma = 1.0 - alpha * jnp.diag(posterior_cov)

    # TODO: Cap alpha with some threshold.
    alpha_numerator = gamma + 2.0 * alpha_prior[0]
    alpha_denominator = weights.squeeze() ** 2 + 2.0 * alpha_prior[1]
    new_alpha = alpha_numerator / alpha_denominator

    beta_numerator = n_samples - gamma.sum() + 2.0 * beta_prior[0]
    beta_denominator = sum_sq_residual + 2.0 * beta_prior[1]
    new_beta = beta_numerator / beta_denominator

    # Re-pack as [alpha..., beta]; beta is a scalar and needs a leading axis.
    return jnp.concatenate([new_alpha, new_beta[jnp.newaxis]], axis=0)

Let's first create some fake input for the neural network:

In [3]:
# Build a 10 x 100 (t, x) grid and evaluate `doublesoliton` on it
# (presumably a two-soliton KdV solution, given the modax.data.kdv import).
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

# Flatten to one row per grid point: X has columns (t, x), y is a column of u.
X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
# Add Gaussian noise scaled to 10% of the standard deviation of y.
# NOTE(review): `key` is consumed here and again by `model.init(key, X)` in the
# next cell; JAX recommends splitting keys instead of reusing them. Changing it
# would alter every recorded output below, so it is only flagged here.
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [4]:
# %% Building model and params
# The list is Deepmod's only constructor argument (presumably the layer
# widths — confirm against modax.models.Deepmod).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

# A forward pass returns four values; `theta` is 2-D (its shape is unpacked
# into two values in the next cell).
prediction, dt, theta, coeffs = model.apply(variables, X)
# Divide each column of theta by its Euclidean norm.
theta_normed = theta / jnp.linalg.norm(theta, axis=0)

In [5]:
# Row/column counts of the (un-normalised) library matrix.
n_samples, n_features = theta.shape
# Both extra arguments forwarded to `precision` are zero.
prior_params_mse = (0.0, 0.0)
# NOTE(review): `precision` is imported from modax.training.losses.utils;
# presumably it estimates the noise precision of (y - prediction) — confirm.
tau = precision(y, prediction, *prior_params_mse)

In [6]:
# Each prior is a pair: in `update`, entry 0 enters the numerator and entry 1
# the denominator of the corresponding re-estimation formula.
alpha_prior = (1e-6, 1e-6)
# tau is wrapped in stop_gradient, so no gradient flows into it via this prior.
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))

In [7]:
n_samples, n_features = theta_normed.shape
# One entry per element of the packed [alpha..., beta] vector: 1 for every
# alpha, 0 for the trailing beta (presumably used by the solvers to weight the
# convergence check — confirm against modax.utils.forward_solver).
norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)
# Initial iterate: alpha = 1 for every feature, beta = 1 / (var(dt) + 1e-7).
prior_init = jnp.concatenate([jnp.ones((n_features,)), (1.0 / (jnp.var(dt) + 1e-7))[jnp.newaxis]], axis=0)
# Computed once from theta_normed and handed to `update` as fixed inputs.
gram = jnp.dot(theta_normed.T, theta_normed)
XT_y = jnp.dot(theta_normed.T, dt)

tol = 1e-3
# NOTE(review): the solver metrics recorded further down report 5000 as their
# first entry, so presumably the iteration cap is hit before `tol` is reached.
max_iter = 5000 # high to ensure we're at a minimum

Let's do the non-custom diff version first:

In [8]:
def update_fn(z):
    """Apply one `update` step to `z` with all data arguments held fixed."""
    return update(z, theta_normed, dt, gram, XT_y, alpha_prior, beta_prior)


# Iterate the fixed-point map starting from prior_init.
prior_params, metrics = fwd_solver(
    update_fn,
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [9]:
print(prior_params, metrics)

[3.8263279e+01 1.4361025e+03 1.5917896e+00 1.4652253e+03 3.5256773e-01
 1.3987609e+03 1.0430831e+03 1.2461372e+03 8.3053589e+02 1.3490372e+03
 8.2156012e+02 1.3356289e+03 4.2957764e+00] (DeviceArray(5000, dtype=int32), DeviceArray(51.810234, dtype=float32))


Quickly comparing with the custom version (should be the same since it's a thin wrapper):

In [10]:
# Same problem as above, but through `fixed_point_solver`: the update function
# and its extra arguments are passed separately instead of being closed over.
prior_params, metrics = fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [11]:
print(prior_params, metrics)

[3.8263279e+01 1.4361025e+03 1.5917896e+00 1.4652253e+03 3.5256773e-01
 1.3987609e+03 1.0430831e+03 1.2461372e+03 8.3053589e+02 1.3490372e+03
 8.2156012e+02 1.3356289e+03 4.2957764e+00] (DeviceArray(5000, dtype=int32), DeviceArray(51.810234, dtype=float32))


Okay that's fine, as expected. Now to check the grad

In [12]:
# Gradient, w.r.t. X, of the first entry of the solved parameter vector
# ([0] picks the parameters from the solver's return value, [0] again the first entry).
# NOTE(review): gram and XT_y are taken from the enclosing scope rather than
# recomputed from X, so the gradient only flows through X's direct use inside
# `update` — confirm this is intended.
direct_grad = jax.grad(lambda X: fwd_solver(lambda z: update(z, X, dt, gram, XT_y, alpha_prior, beta_prior), prior_init, norm_weight, tol=tol,max_iter=max_iter)[0][0])

In [13]:
direct_grad(theta_normed)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

So we can't backpropagate through the JAX while loop, and scan doesn't work for us; maybe a nice first thing for me to fix.

In [14]:
def update_fn(z):
    """Apply one `update` step to `z` with all data arguments held fixed."""
    return update(z, theta_normed, dt, gram, XT_y, alpha_prior, beta_prior)


# Same iteration as before, run through the simple forward solver.
prior_params, metrics = fwd_solver_simple(
    update_fn,
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [15]:
prior_params

Buffer([3.8263054e+01, 1.4373539e+03, 1.5917894e+00, 1.4648746e+03,
        3.5256758e-01, 1.3950604e+03, 1.0434330e+03, 1.2461440e+03,
        8.3029584e+02, 1.3491497e+03, 8.2146710e+02, 1.3354734e+03,
        4.2957764e+00], dtype=float32)

In [16]:
metrics

(Buffer(5000, dtype=int32), Buffer(47.584972, dtype=float32))

The result of the forward pass is nearly the same (a few entries differ slightly); now let's try to backpropagate:

In [17]:
# Gradient, w.r.t. X, of the first entry of the solved parameter vector, this
# time differentiating through `fwd_solver_simple`.
# NOTE(review): gram and XT_y are taken from the enclosing scope rather than
# recomputed from X, so the gradient only flows through X's direct use inside
# `update` — confirm this is intended.
direct_grad = jax.grad(lambda X: fwd_solver_simple(lambda z: update(z, X, dt, gram, XT_y, alpha_prior, beta_prior), prior_init, norm_weight, tol=tol,max_iter=max_iter)[0][0])

In [18]:
direct_grad(theta_normed)

DeviceArray([[-3.85704334e-04,  1.27980854e-07, -5.24493447e-03, ...,
              -3.92538277e-06, -5.20584126e-06,  3.69631675e-06],
             [-3.57993034e-04,  1.18785984e-07, -4.86810878e-03, ...,
              -3.64335756e-06, -4.83182157e-06,  3.43075317e-06],
             [-3.28540336e-04,  1.09013293e-07, -4.46760282e-03, ...,
              -3.34361425e-06, -4.43430190e-06,  3.14849990e-06],
             ...,
             [-5.86798997e-04,  1.94706232e-07, -7.97948986e-03, ...,
              -5.97195231e-06, -7.92000537e-06,  5.62346486e-06],
             [-5.70365286e-04,  1.89253313e-07, -7.75601715e-03, ...,
              -5.80471124e-06, -7.69820326e-06,  5.46597767e-06],
             [-5.54216327e-04,  1.83894926e-07, -7.53641920e-03, ...,
              -5.64036327e-06, -7.48024422e-06,  5.31121987e-06]],            dtype=float32)

In [22]:
# Forward pass only: display the solved parameters and the solver metrics.
fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

(Buffer([3.8263279e+01, 1.4361025e+03, 1.5917896e+00, 1.4652253e+03,
         3.5256773e-01, 1.3987609e+03, 1.0430831e+03, 1.2461372e+03,
         8.3053589e+02, 1.3490372e+03, 8.2156012e+02, 1.3356289e+03,
         4.2957764e+00], dtype=float32),
 (Buffer(5000, dtype=int32), Buffer(51.810234, dtype=float32)))

In [21]:
# Gradient of the first solved parameter w.r.t. X through `fixed_point_solver`.
# NOTE(review): only the first element of the argument tuple depends on X;
# gram and XT_y are passed as fixed values, so their dependence on
# theta_normed is not differentiated — confirm this is intended.
jax.grad(lambda X: fixed_point_solver(
    update,
    (X, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)[0][0])(theta_normed)

DeviceArray([[nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             ...,
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan],
             [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

So we still get NaNs. Maybe increase max_iter?

In [24]:
# Raise the iteration cap by a factor of 20. An iteration cap is a count, so
# it is kept as an integer (it was previously the float literal 1e5), matching
# the integer value used for max_iter earlier in the notebook.
max_iter = 100_000
# Forward pass only: display the solved parameters and the solver metrics.
fixed_point_solver(
    update,
    (theta_normed, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

(Buffer([3.8262684e+01, 1.4509980e+03, 1.5917910e+00, 1.4644015e+03,
         3.5256791e-01, 1.3952686e+03, 1.0433105e+03, 1.2459473e+03,
         8.3051746e+02, 1.3494642e+03, 8.2125189e+02, 1.3354650e+03,
         4.2957759e+00], dtype=float32),
 (Buffer(100000, dtype=int32), Buffer(22.624489, dtype=float32)))

In [25]:
# Gradient of the first solved parameter w.r.t. X through `fixed_point_solver`,
# now using the larger max_iter set in the previous cell.
# NOTE(review): only the first element of the argument tuple depends on X;
# gram and XT_y are passed as fixed values, so their dependence on
# theta_normed is not differentiated — confirm this is intended.
jax.grad(lambda X: fixed_point_solver(
    update,
    (X, dt, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)[0][0])(theta_normed)

Buffer([[ 4.8051970e+06, -1.6028442e+03,  6.5343312e+07, ...,
          4.8894770e+04,  6.4869398e+04, -4.6046195e+04],
        [ 4.4599630e+06, -1.4876863e+03,  6.0648656e+07, ...,
          4.5381875e+04,  6.0208793e+04, -4.2737961e+04],
        [ 4.0930350e+06, -1.3652920e+03,  5.5658996e+07, ...,
          4.1648238e+04,  5.5255320e+04, -3.9221840e+04],
        ...,
        [ 7.3104715e+06, -2.4385154e+03,  9.9411200e+07, ...,
          7.4386922e+04,  9.8690203e+04, -7.0053188e+04],
        [ 7.1057390e+06, -2.3702236e+03,  9.6627152e+07, ...,
          7.2303688e+04,  9.5926344e+04, -6.8091320e+04],
        [ 6.9045555e+06, -2.3031160e+03,  9.3891360e+07, ...,
          7.0256562e+04,  9.3210398e+04, -6.6163469e+04]], dtype=float32)

That seems a bit high, so something is definitely wrong. Let's leave it for now and focus on the evidence, since we don't even have to backpropagate initially.